feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
Author: 1808837298@qq.com
Date: 2024-10-06 14:13:41 +08:00
Committed by: CalciumIon
Parent: 33af069fae
Commit: 74f9006b40
15 changed files with 227 additions and 62 deletions

View File

@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
+completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
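
Throughout this commit the text token counter is renamed from service.CountTokenText to service.CountTextToken; the call sites keep the same (text, model) argument order. As a rough sketch only (not the project's actual implementation), a counter with that shape could be built on github.com/pkoukk/tiktoken-go:

package service // hypothetical placement, for illustration

import "github.com/pkoukk/tiktoken-go"

// CountTextToken returns the token count of text for the given model.
func CountTextToken(text string, model string) (int, error) {
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		// unknown model: fall back to a default encoding
		enc, err = tiktoken.GetEncoding("cl100k_base")
		if err != nil {
			return 0, err
		}
	}
	return len(enc.Encode(text, nil, nil)), nil
}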

View File

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
-usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
+usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage

View File

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
-usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage
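
Note that, unlike every other CountTextToken call in this commit, this fallback appears to pass the literal model name "gpt-3.5-turbo" as the text argument and responseText as the model. Assuming the signature is (text, model), the intended call would presumably be:

usage.CompletionTokens, _ = service.CountTextToken(responseText, "gpt-3.5-turbo")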

View File

@@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+if info.RelayMode == constant.RelayModeRealtime {
+requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
+}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
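
For illustration, with a hypothetical deployment gpt-4o-realtime-preview and api-version 2024-10-01-preview, the realtime branch above produces:

// illustrative values only
model_ := "gpt-4o-realtime-preview"
apiVersion := "2024-10-01-preview"
requestURL := fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
// => /openai/realtime?deployment=gpt-4o-realtime-preview&api-version=2024-10-01-preview

which relaycommon.GetFullRequestURL then resolves against the channel's base URL.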

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
+"log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
-ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
+ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
-usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
@@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
+localUsage := &dto.RealtimeUsage{}
go func() {
for {
@@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
return
}
+realtimeEvent := &dto.RealtimeEvent{}
+err = json.Unmarshal(message, realtimeEvent)
+if err != nil {
+errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+return
+}
+if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
+if realtimeEvent.Session != nil {
+if realtimeEvent.Session.Tools != nil {
+info.RealtimeTools = realtimeEvent.Session.Tools
+}
+}
+}
+textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+if err != nil {
+errChan <- fmt.Errorf("error counting text token: %v", err)
+return
+}
+log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+localUsage.TotalTokens += textToken + audioToken
+localUsage.InputTokens += textToken
+localUsage.InputTokenDetails.TextTokens += textToken
+localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
@@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
}
+} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
+realtimeSession := realtimeEvent.Session
+if realtimeSession != nil {
+// update audio format
+info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
+info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
+}
+} else {
+textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+if err != nil {
+errChan <- fmt.Errorf("error counting text token: %v", err)
+return
+}
+log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+localUsage.TotalTokens += textToken + audioToken
+if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
+info.IsFirstRequest = false
+localUsage.InputTokens += textToken + audioToken
+localUsage.InputTokenDetails.TextTokens += textToken
+localUsage.InputTokenDetails.AudioTokens += audioToken
+} else {
+localUsage.OutputTokens += textToken + audioToken
+localUsage.OutputTokenDetails.TextTokens += textToken
+localUsage.OutputTokenDetails.AudioTokens += audioToken
+}
+}
err = service.WssString(c, clientConn, string(message))
@@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
case <-c.Done():
}
+// check usage total tokens, if 0, use local usage
+if usage.TotalTokens == 0 {
+usage = localUsage
+}
return nil, usage
}
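
The session.created/session.updated branch above only adopts the audio formats the upstream actually reports; from its call sites, common.GetStringIfEmpty presumably behaves like the following sketch (the real implementation in the common package may differ):

// Return the preferred value unless it is empty, otherwise keep the fallback.
func GetStringIfEmpty(preferred string, fallback string) string {
	if preferred == "" {
		return fallback
	}
	return preferred
}

The per-message counts from service.CountTokenRealtime accumulate into localUsage in both directions, and localUsage is only consulted at the end of the session when the upstream never reported usage (usage.TotalTokens == 0).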

View File

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"one-api/common"
+"one-api/dto"
"one-api/relay/constant"
"strings"
"time"
@@ -35,11 +36,18 @@ type RelayInfo struct {
ShouldIncludeUsage bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
+InputAudioFormat string
+OutputAudioFormat string
+RealtimeTools []dto.RealTimeTool
+IsFirstRequest bool
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
+info.InputAudioFormat = "pcm16"
+info.OutputAudioFormat = "pcm16"
+info.IsFirstRequest = true
return info
}
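
GenRelayInfoWs seeds the realtime defaults (pcm16 audio in both directions, IsFirstRequest = true) before any session.created/session.updated event can override them. A hypothetical call site when upgrading the client connection (the real controller code may differ):

var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

func RelayRealtime(c *gin.Context) { // hypothetical handler name
	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		return
	}
	defer ws.Close()
	info := relaycommon.GenRelayInfoWs(c, ws) // pcm16 defaults, IsFirstRequest = true
	_ = info
}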

View File

@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
+promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}

View File

@@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
quota := 0
if !usePrice {
quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio))
+quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
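
The one-line change above prices audio output tokens at audioRatio*audioCompletionRatio rather than completionRatio*audioCompletionRatio. A worked example with illustrative numbers only (not the project's real ratios): textInputTokens=100, textOutTokens=200, audioInputTokens=500, audioOutTokens=300, completionRatio=4, audioRatio=20, audioCompletionRatio=2, ratio=1:

quota := 100 + int(math.Round(200*4.0))                            // 900 for text
quota += int(math.Round(500*20.0)) + int(math.Round(300*20.0*2.0)) // +10000 audio in, +12000 audio out
// total: 22900; the old formula priced audio output at 300*4*2 = 2400,
// undercharging whenever audioRatio exceeds completionRatio.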
@@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
//}
}
-func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
-var promptTokens int
-var err error
-switch info.RelayMode {
-default:
-promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
-}
-info.PromptTokens = promptTokens
-return promptTokens, err
-}
+//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
+// var promptTokens int
+// var err error
+// switch info.RelayMode {
+// default:
+// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
+// }
+// info.PromptTokens = promptTokens
+// return promptTokens, err
+//}
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
// var err error