feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
This commit is contained in:
1808837298@qq.com
2024-10-06 14:13:41 +08:00
committed by CalciumIon
parent 33af069fae
commit 74f9006b40
15 changed files with 227 additions and 62 deletions

View File

@@ -11,6 +11,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"strings"
"unicode/utf8"
)
@@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil
}
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
tkm := 0
ratio := 1
if request.Session != nil {
msgTokens, err := CountTokenText(request.Session.Instructions, model)
if err != nil {
return 0, err
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens
}
ratio = len(request.Session.Modalities)
tkm += msgTokens
if request.Session.Tools != nil {
toolsData, _ := json.Marshal(request.Session.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
case dto.RealtimeEventResponseAudioDelta:
// count audio token
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventTypeResponseDone:
// count tools token
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8
textToken += toolTokens
}
}
toolTokens, err := CountTokenInput(countStr, model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
}
}
tkm *= ratio
return tkm, nil
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
@@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) {
case string:
return CountTokenText(v, model)
return CountTextToken(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return CountTokenText(text, model)
return CountTextToken(text, model)
}
return CountTokenInput(fmt.Sprintf("%v", input), model)
}
@@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountAudioToken(text string, model string) (int, error) {
func CountTTSToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
} else {
return CountTokenText(text, model)
return CountTextToken(text, model)
}
}
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTokenText(text string, model string) (int, error) {
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 100 / 0.06), nil
}
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err