feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
This commit is contained in:
1808837298@qq.com
2024-10-06 14:13:41 +08:00
committed by CalciumIon
parent 33af069fae
commit 74f9006b40
15 changed files with 227 additions and 62 deletions

View File

@@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
@@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
@@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
go func() {
for {
@@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
@@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
}
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
localUsage.TotalTokens += textToken + audioToken
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
} else {
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
}
err = service.WssString(c, clientConn, string(message))
@@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
case <-c.Done():
}
// check usage total tokens, if 0, use local usage
if usage.TotalTokens == 0 {
usage = localUsage
}
return nil, usage
}