mirror of https://github.com/linux-do/new-api.git (synced 2025-09-17 16:06:38 +08:00)

feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)

parent 33af069fae
commit 74f9006b40
@@ -432,9 +432,23 @@ func GetAudioCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gpt-4o-realtime") {
 		return 10
 	}
-	return 10
+	return 2
 }
 
+//func GetAudioPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.06
+//	}
+//	return 0.06
+//}
+//
+//func GetAudioCompletionPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.24
+//	}
+//	return 0.24
+//}
+
 func GetCompletionRatioMap() map[string]float64 {
 	if CompletionRatio == nil {
 		CompletionRatio = defaultCompletionRatio
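For context on what this ratio change means for billing: the audio completion ratio is multiplied against audio output tokens when quota is computed (see the postWssConsumeQuota hunk further down). A minimal standalone sketch, not part of the commit, using an illustrative audio ratio and token count that are assumptions:

package main

import (
	"fmt"
	"math"
	"strings"
)

// Local stand-in for GetAudioCompletionRatio after this change.
func getAudioCompletionRatio(name string) float64 {
	if strings.HasPrefix(name, "gpt-4o-realtime") {
		return 10
	}
	return 2
}

func main() {
	// Assumed figures, for illustration only.
	audioOutTokens := 1000
	audioRatio := 20.0
	for _, model := range []string{"gpt-4o-realtime-preview", "gpt-4o-audio-preview"} {
		weighted := math.Round(float64(audioOutTokens) * audioRatio * getAudioCompletionRatio(model))
		fmt.Printf("%s: %d audio output tokens -> %.0f weighted tokens before the model ratio\n",
			model, audioOutTokens, weighted)
	}
}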
@@ -5,10 +5,18 @@ const (
 	RealtimeEventTypeSessionUpdate      = "session.update"
 	RealtimeEventTypeConversationCreate = "conversation.item.create"
 	RealtimeEventTypeResponseCreate     = "response.create"
+	RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
 )
 
 const (
 	RealtimeEventTypeResponseDone = "response.done"
+	RealtimeEventTypeSessionUpdated = "session.updated"
+	RealtimeEventTypeSessionCreated = "session.created"
+	RealtimeEventResponseAudioDelta = "response.audio.delta"
+	RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
+	RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
+	RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
+	RealtimeEventConversationItemCreated = "conversation.item.created"
 )
 
 type RealtimeEvent struct {

@@ -19,6 +27,8 @@ type RealtimeEvent struct {
 	Item     *RealtimeItem     `json:"item,omitempty"`
 	Error    *OpenAIError      `json:"error,omitempty"`
 	Response *RealtimeResponse `json:"response,omitempty"`
+	Delta    string            `json:"delta,omitempty"`
+	Audio    string            `json:"audio,omitempty"`
 }
 
 type RealtimeResponse struct {
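To illustrate what these new constants and the Delta/Audio fields correspond to on the wire: each realtime websocket frame is JSON with a type discriminator, and the relay switches on it to decide what to count. A minimal sketch, not from the commit, using a local struct that mirrors the dto fields above and an invented example payload:

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the subset of dto.RealtimeEvent used in this sketch.
type realtimeEvent struct {
	Type  string `json:"type"`
	Delta string `json:"delta,omitempty"`
	Audio string `json:"audio,omitempty"`
}

func main() {
	// Example frame as a client might send while streaming microphone audio.
	raw := []byte(`{"type":"input_audio_buffer.append","audio":"UklGRg=="}`)

	var ev realtimeEvent
	if err := json.Unmarshal(raw, &ev); err != nil {
		panic(err)
	}

	switch ev.Type {
	case "input_audio_buffer.append": // RealtimeEventInputAudioBufferAppend
		fmt.Printf("audio chunk, %d base64 chars -> counted as input audio tokens\n", len(ev.Audio))
	case "response.audio.delta": // RealtimeEventResponseAudioDelta
		fmt.Println("audio output delta -> counted as output audio tokens")
	case "response.done": // RealtimeEventTypeResponseDone
		fmt.Println("response finished -> upstream usage, plus tool definitions, are charged")
	default:
		fmt.Println("other event:", ev.Type)
	}
}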
@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 		}, nil
 	}
 	fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
+	completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 	return nil, usage

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	return nil, usage
@@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		model_ := info.UpstreamModelName
 		model_ = strings.Replace(model_, ".", "", -1)
 		// https://github.com/songquanpeng/one-api/issues/67
 
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		if info.RelayMode == constant.RelayModeRealtime {
+			requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
+		}
 		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
 	case common.ChannelTypeMiniMax:
 		return minimax.GetRequestURL(info)

@@ -9,6 +9,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
+			ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{

@@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 }
@@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	errChan := make(chan error, 2)
 
 	usage := &dto.RealtimeUsage{}
+	localUsage := &dto.RealtimeUsage{}
 
 	go func() {
 		for {

@@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				return
 			}
 
+			realtimeEvent := &dto.RealtimeEvent{}
+			err = json.Unmarshal(message, realtimeEvent)
+			if err != nil {
+				errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+				return
+			}
+
+			if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
+				if realtimeEvent.Session != nil {
+					if realtimeEvent.Session.Tools != nil {
+						info.RealtimeTools = realtimeEvent.Session.Tools
+					}
+				}
+			}
+
+			textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+			if err != nil {
+				errChan <- fmt.Errorf("error counting text token: %v", err)
+				return
+			}
+			log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+			localUsage.TotalTokens += textToken + audioToken
+			localUsage.InputTokens += textToken
+			localUsage.InputTokenDetails.TextTokens += textToken
+			localUsage.InputTokenDetails.AudioTokens += audioToken
+
 			err = service.WssString(c, targetConn, string(message))
 			if err != nil {
 				errChan <- fmt.Errorf("error writing to target: %v", err)

@@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 					usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
 					usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
 				}
+			} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
+				realtimeSession := realtimeEvent.Session
+				if realtimeSession != nil {
+					// update audio format
+					info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
+					info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
+				}
+			} else {
+				textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+				if err != nil {
+					errChan <- fmt.Errorf("error counting text token: %v", err)
+					return
+				}
+				log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+				localUsage.TotalTokens += textToken + audioToken
+
+				if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
+					info.IsFirstRequest = false
+					localUsage.InputTokens += textToken + audioToken
+					localUsage.InputTokenDetails.TextTokens += textToken
+					localUsage.InputTokenDetails.AudioTokens += audioToken
+				} else {
+					localUsage.OutputTokens += textToken + audioToken
+					localUsage.OutputTokenDetails.TextTokens += textToken
+					localUsage.OutputTokenDetails.AudioTokens += audioToken
+				}
 			}
 
 			err = service.WssString(c, clientConn, string(message))
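The session.created / session.updated branch only swaps the relay's accounting format when the session actually reports one. A minimal sketch of the behavior assumed for common.GetStringIfEmpty (the helper itself is not shown in this commit), using a local stand-in:

package main

import "fmt"

// Local stand-in with the behavior assumed for common.GetStringIfEmpty:
// keep the fallback when the new value is empty, otherwise take the new value.
func getStringIfEmpty(val string, fallback string) string {
	if val == "" {
		return fallback
	}
	return val
}

func main() {
	inputAudioFormat := "pcm16" // default set in GenRelayInfoWs

	// A session.updated event without an explicit format keeps the default...
	inputAudioFormat = getStringIfEmpty("", inputAudioFormat)
	fmt.Println(inputAudioFormat) // pcm16

	// ...while a session that declares g711_ulaw switches the format used for token counting.
	inputAudioFormat = getStringIfEmpty("g711_ulaw", inputAudioFormat)
	fmt.Println(inputAudioFormat) // g711_ulaw
}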
@@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	case <-c.Done():
 	}
 
+	// check usage total tokens, if 0, use local usage
+
+	if usage.TotalTokens == 0 {
+		usage = localUsage
+	}
 	return nil, usage
 }

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,
@@ -4,6 +4,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/relay/constant"
 	"strings"
 	"time"

@@ -35,11 +36,18 @@ type RelayInfo struct {
 	ShouldIncludeUsage bool
 	ClientWs           *websocket.Conn
 	TargetWs           *websocket.Conn
+	InputAudioFormat   string
+	OutputAudioFormat  string
+	RealtimeTools      []dto.RealTimeTool
+	IsFirstRequest     bool
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.ClientWs = ws
+	info.InputAudioFormat = "pcm16"
+	info.OutputAudioFormat = "pcm16"
+	info.IsFirstRequest = true
 	return info
 }

@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
+		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}
@@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	quota := 0
 	if !usePrice {
 		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio))
+		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
 
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {
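The change above reweights audio output tokens by audioRatio*audioCompletionRatio instead of completionRatio*audioCompletionRatio. A worked example of the corrected formula, standalone and with token counts and ratios that are assumptions rather than values from this commit:

package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumed token counts and ratios, for illustration only.
	textInputTokens, textOutTokens := 100, 200
	audioInputTokens, audioOutTokens := 300, 400
	completionRatio := 4.0       // text output weight
	audioRatio := 20.0           // audio input weight
	audioCompletionRatio := 10.0 // audio output weight on top of audioRatio
	ratio := 2.5                 // model/group ratio applied last

	quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
	// After this commit the audio output term uses audioRatio, not completionRatio:
	quota += int(math.Round(float64(audioInputTokens)*audioRatio)) +
		int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
	quota = int(math.Round(float64(quota) * ratio))

	fmt.Println("quota:", quota) // (100 + 800 + 6000 + 80000) * 2.5 = 217250
}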
@@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	//}
 }
 
-func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
-	var promptTokens int
-	var err error
-	switch info.RelayMode {
-	default:
-		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
-	}
-	info.PromptTokens = promptTokens
-	return promptTokens, err
-}
+//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
+//	var promptTokens int
+//	var err error
+//	switch info.RelayMode {
+//	default:
+//		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
+//	}
+//	info.PromptTokens = promptTokens
+//	return promptTokens, err
+//}
 
 //func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
 //	var err error
service/audio.go (new file, 31 lines)
@@ -0,0 +1,31 @@
+package service
+
+import (
+	"encoding/base64"
+	"fmt"
+)
+
+func parseAudio(audioBase64 string, format string) (duration float64, err error) {
+	audioData, err := base64.StdEncoding.DecodeString(audioBase64)
+	if err != nil {
+		return 0, fmt.Errorf("base64 decode error: %v", err)
+	}
+
+	var samplesCount int
+	var sampleRate int
+
+	switch format {
+	case "pcm16":
+		samplesCount = len(audioData) / 2 // 16-bit = 2 bytes per sample
+		sampleRate = 24000 // 24kHz
+	case "g711_ulaw", "g711_alaw":
+		samplesCount = len(audioData) // 8-bit = 1 byte per sample
+		sampleRate = 8000 // 8kHz
+	default:
+		samplesCount = len(audioData) // 8-bit = 1 byte per sample
+		sampleRate = 8000 // 8kHz
+	}
+
+	duration = float64(samplesCount) / float64(sampleRate)
+	return duration, nil
+}
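parseAudio derives duration purely from the byte length: pcm16 is treated as 16-bit mono at 24 kHz, the G.711 variants (and the default branch) as 8-bit at 8 kHz. A standalone sketch of the same arithmetic on a synthetic buffer (not a real audio clip, and not calling the unexported parseAudio itself):

package main

import (
	"encoding/base64"
	"fmt"
)

func main() {
	// 48,000 bytes of 16-bit PCM at 24 kHz = 24,000 samples = 1.0 second.
	pcm := make([]byte, 48000)
	encoded := base64.StdEncoding.EncodeToString(pcm)

	samples := len(pcm) / 2 // pcm16: 2 bytes per sample
	duration := float64(samples) / 24000.0
	fmt.Printf("decoded %d bytes -> %.2f s of pcm16 audio\n", len(pcm), duration)

	_ = encoded // parseAudio above would base64-decode this and apply the same arithmetic
}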
@@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 		common.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
-	common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
 	return ws.WriteMessage(1, []byte(str))
 }
 

@@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
 		common.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
-	common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
 	return ws.WriteMessage(1, jsonData)
 }
 

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	relaycommon "one-api/relay/common"
 	"strings"
 	"unicode/utf8"
 )
@@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
 	return tkm, nil
 }
 
-func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
-	tkm := 0
-	ratio := 1
-	if request.Session != nil {
-		msgTokens, err := CountTokenText(request.Session.Instructions, model)
-		if err != nil {
-			return 0, err
-		}
-		ratio = len(request.Session.Modalities)
-		tkm += msgTokens
-		if request.Session.Tools != nil {
-			toolsData, _ := json.Marshal(request.Session.Tools)
-			var openaiTools []dto.OpenAITools
-			err := json.Unmarshal(toolsData, &openaiTools)
-			if err != nil {
-				return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
-			}
-			countStr := ""
-			for _, tool := range openaiTools {
-				countStr = tool.Function.Name
-				if tool.Function.Description != "" {
-					countStr += tool.Function.Description
-				}
-				if tool.Function.Parameters != nil {
-					countStr += fmt.Sprintf("%v", tool.Function.Parameters)
-				}
-			}
-			toolTokens, err := CountTokenInput(countStr, model)
-			if err != nil {
-				return 0, err
-			}
-			tkm += 8
-			tkm += toolTokens
-		}
-	}
-	tkm *= ratio
-	return tkm, nil
-}
+func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
+	audioToken := 0
+	textToken := 0
+	switch request.Type {
+	case dto.RealtimeEventTypeSessionUpdate:
+		if request.Session != nil {
+			msgTokens, err := CountTextToken(request.Session.Instructions, model)
+			if err != nil {
+				return 0, 0, err
+			}
+			textToken += msgTokens
+		}
+	case dto.RealtimeEventResponseAudioDelta:
+		// count audio token
+		atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
+		// count text token
+		tkm, err := CountTextToken(request.Delta, model)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting text token: %v", err)
+		}
+		textToken += tkm
+	case dto.RealtimeEventInputAudioBufferAppend:
+		// count audio token
+		atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventTypeResponseDone:
+		// count tools token
+		if !info.IsFirstRequest {
+			if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
+				for _, tool := range info.RealtimeTools {
+					toolTokens, err := CountTokenInput(tool, model)
+					if err != nil {
+						return 0, 0, err
+					}
+					textToken += 8
+					textToken += toolTokens
+				}
+			}
+		}
+	}
+	return textToken, audioToken, nil
+}
 
 func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
@@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
 func CountTokenInput(input any, model string) (int, error) {
 	switch v := input.(type) {
 	case string:
-		return CountTokenText(v, model)
+		return CountTextToken(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
 	}
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
@@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountAudioToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) (int, error) {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text), nil
 	} else {
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
 	}
 }
 
-// CountTokenText counts the tokens in a text; it returns an error only when the text contains sensitive words, and still returns the token count
-func CountTokenText(text string, model string) (int, error) {
+func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
+	}
+	return int(duration / 60 * 100 / 0.06), nil
+}
+
+func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
+	}
+	return int(duration / 60 * 200 / 0.24), nil
+}
+
+//func CountAudioToken(sec float64, audioType string) {
+//	if audioType == "input" {
+//
+//	}
+//}
+
+// CountTextToken counts the tokens in a text; it returns an error only when the text contains sensitive words, and still returns the token count
+func CountTextToken(text string, model string) (int, error) {
 	var err error
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text), err
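CountAudioTokenInput and CountAudioTokenOutput convert an audio duration into a token count with fixed constants (100/0.06 per minute of input, 200/0.24 per minute of output); these constants look like they back token counts out of per-minute audio pricing, though the commit does not say so. A worked example of the conversion, with an assumed duration:

package main

import "fmt"

func main() {
	// Reproduce the conversion used by CountAudioTokenInput / CountAudioTokenOutput.
	durationSeconds := 90.0 // assumed: a minute and a half of audio

	inputTokens := int(durationSeconds / 60 * 100 / 0.06)  // ~1666 tokens per minute of input audio
	outputTokens := int(durationSeconds / 60 * 200 / 0.24) // ~833 tokens per minute of output audio

	fmt.Println(inputTokens, outputTokens) // 2500 1250
}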
@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTokenText(responseText, modeName)
+	ctkm, err := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err