feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
This commit is contained in:
1808837298@qq.com
2024-10-06 14:13:41 +08:00
committed by CalciumIon
parent 33af069fae
commit 74f9006b40
15 changed files with 227 additions and 62 deletions

31
service/audio.go Normal file
View File

@@ -0,0 +1,31 @@
package service
import (
"encoding/base64"
"fmt"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
if err != nil {
return 0, fmt.Errorf("base64 decode error: %v", err)
}
var samplesCount int
var sampleRate int
switch format {
case "pcm16":
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
sampleRate = 24000 // 24kHz
case "g711_ulaw", "g711_alaw":
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
default:
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
}
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}

View File

@@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str))
}
@@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
"unicode/utf8"
)
@@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil
}
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
tkm := 0
ratio := 1
if request.Session != nil {
msgTokens, err := CountTokenText(request.Session.Instructions, model)
if err != nil {
return 0, err
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens
}
ratio = len(request.Session.Modalities)
tkm += msgTokens
if request.Session.Tools != nil {
toolsData, _ := json.Marshal(request.Session.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
case dto.RealtimeEventResponseAudioDelta:
// count audio token
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventTypeResponseDone:
// count tools token
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8
textToken += toolTokens
}
}
toolTokens, err := CountTokenInput(countStr, model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
}
}
tkm *= ratio
return tkm, nil
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
@@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
return CountTextToken(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
return CountTextToken(text, model)
}
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
@@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountAudioToken(text string, model string) (int, error) {
func CountTTSToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model)
return CountTextToken(text, model)
}
}
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string) (int, error) {
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 100 / 0.06), nil
}
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err

View File

@@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm, err := CountTokenText(responseText, modeName)
ctkm, err := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err