feat: realtime

(cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8)
This commit is contained in:
1808837298@qq.com
2024-10-04 16:08:18 +08:00
committed by CalciumIon
parent e3c85572d4
commit 33af069fae
37 changed files with 759 additions and 156 deletions

View File

@@ -2,6 +2,7 @@ package service
import (
"github.com/gin-gonic/gin"
"one-api/dto"
relaycommon "one-api/relay/common"
)
@@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["admin_info"] = adminInfo
return other
}
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
info["text_input"] = usage.InputTokenDetails.TextTokens
info["text_output"] = usage.OutputTokenDetails.TextTokens
return info
}

View File

@@ -43,11 +43,25 @@ func Done(c *gin.Context) {
_ = StringData(c, "[DONE]")
}
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str))
}
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}

View File

@@ -191,6 +191,45 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil
}
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
tkm := 0
ratio := 1
if request.Session != nil {
msgTokens, err := CountTokenText(request.Session.Instructions, model)
if err != nil {
return 0, err
}
ratio = len(request.Session.Modalities)
tkm += msgTokens
if request.Session.Tools != nil {
toolsData, _ := json.Marshal(request.Session.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
}
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
}
}
tkm *= ratio
return tkm, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)