feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
This commit is contained in:
1808837298@qq.com 2024-10-06 14:13:41 +08:00 committed by CalciumIon
parent 33af069fae
commit 74f9006b40
15 changed files with 227 additions and 62 deletions

View File

@ -432,9 +432,23 @@ func GetAudioCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") { if strings.HasPrefix(name, "gpt-4o-realtime") {
return 10 return 10
} }
return 10 return 2
} }
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 { func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil { if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio CompletionRatio = defaultCompletionRatio

View File

@ -5,10 +5,18 @@ const (
RealtimeEventTypeSessionUpdate = "session.update" RealtimeEventTypeSessionUpdate = "session.update"
RealtimeEventTypeConversationCreate = "conversation.item.create" RealtimeEventTypeConversationCreate = "conversation.item.create"
RealtimeEventTypeResponseCreate = "response.create" RealtimeEventTypeResponseCreate = "response.create"
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
) )
const ( const (
RealtimeEventTypeResponseDone = "response.done" RealtimeEventTypeResponseDone = "response.done"
RealtimeEventTypeSessionUpdated = "session.updated"
RealtimeEventTypeSessionCreated = "session.created"
RealtimeEventResponseAudioDelta = "response.audio.delta"
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
RealtimeEventConversationItemCreated = "conversation.item.created"
) )
type RealtimeEvent struct { type RealtimeEvent struct {
@ -19,6 +27,8 @@ type RealtimeEvent struct {
Item *RealtimeItem `json:"item,omitempty"` Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"` Error *OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"` Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
} }
type RealtimeResponse struct { type RealtimeResponse struct {

View File

@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}, nil }, nil
} }
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
} }

View File

@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage

View File

@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
} }
if usage.TotalTokens == 0 { if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
return nil, usage return nil, usage

View File

@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info) return minimax.GetRequestURL(info)

View File

@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"io" "io"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{
@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage
} }
@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan := make(chan error, 2) errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{} usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
go func() { go func() {
for { for {
@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
return return
} }
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message)) err = service.WssString(c, targetConn, string(message))
if err != nil { if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err) errChan <- fmt.Errorf("error writing to target: %v", err)
@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
} }
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
localUsage.TotalTokens += textToken + audioToken
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
} else {
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
} }
err = service.WssString(c, clientConn, string(message)) err = service.WssString(c, clientConn, string(message))
@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
case <-c.Done(): case <-c.Done():
} }
// check usage total tokens, if 0, use local usage
if usage.TotalTokens == 0 {
usage = localUsage
}
return nil, usage return nil, usage
} }

View File

@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil }, nil
} }
fullTextResponse := responsePaLM2OpenAI(&palmResponse) fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{ usage := dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,

View File

@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/relay/constant" "one-api/relay/constant"
"strings" "strings"
"time" "time"
@ -35,11 +36,18 @@ type RelayInfo struct {
ShouldIncludeUsage bool ShouldIncludeUsage bool
ClientWs *websocket.Conn ClientWs *websocket.Conn
TargetWs *websocket.Conn TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
} }
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.ClientWs = ws info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
info.IsFirstRequest = true
return info return info
} }

View File

@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0 promptTokens := 0
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
} }

View File

@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
quota := 0 quota := 0
if !usePrice { if !usePrice {
quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio)) quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
quota = int(math.Round(float64(quota) * ratio)) quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {
@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
//} //}
} }
func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { //func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int // var promptTokens int
var err error // var err error
switch info.RelayMode { // switch info.RelayMode {
default: // default:
promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) // promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
} // }
info.PromptTokens = promptTokens // info.PromptTokens = promptTokens
return promptTokens, err // return promptTokens, err
} //}
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { //func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
// var err error // var err error

31
service/audio.go Normal file
View File

@ -0,0 +1,31 @@
package service
import (
"encoding/base64"
"fmt"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
if err != nil {
return 0, fmt.Errorf("base64 decode error: %v", err)
}
var samplesCount int
var sampleRate int
switch format {
case "pcm16":
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
sampleRate = 24000 // 24kHz
case "g711_ulaw", "g711_alaw":
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
default:
samplesCount = len(audioData) // 8位 = 1字节每样本
sampleRate = 8000 // 8kHz
}
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}

View File

@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
common.LogError(c, "websocket connection is nil") common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil") return errors.New("websocket connection is nil")
} }
common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) //common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
return ws.WriteMessage(1, []byte(str)) return ws.WriteMessage(1, []byte(str))
} }
@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
common.LogError(c, "websocket connection is nil") common.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil") return errors.New("websocket connection is nil")
} }
common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData) return ws.WriteMessage(1, jsonData)
} }

View File

@ -11,6 +11,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
return tkm, nil return tkm, nil
} }
func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) { func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
tkm := 0 audioToken := 0
ratio := 1 textToken := 0
if request.Session != nil { switch request.Type {
msgTokens, err := CountTokenText(request.Session.Instructions, model) case dto.RealtimeEventTypeSessionUpdate:
if err != nil { if request.Session != nil {
return 0, err msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens
} }
ratio = len(request.Session.Modalities) case dto.RealtimeEventResponseAudioDelta:
tkm += msgTokens // count audio token
if request.Session.Tools != nil { atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
toolsData, _ := json.Marshal(request.Session.Tools) if err != nil {
var openaiTools []dto.OpenAITools return 0, 0, fmt.Errorf("error counting audio token: %v", err)
err := json.Unmarshal(toolsData, &openaiTools) }
if err != nil { audioToken += atk
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
} // count text token
countStr := "" tkm, err := CountTextToken(request.Delta, model)
for _, tool := range openaiTools { if err != nil {
countStr = tool.Function.Name return 0, 0, fmt.Errorf("error counting text token: %v", err)
if tool.Function.Description != "" { }
countStr += tool.Function.Description textToken += tkm
} case dto.RealtimeEventInputAudioBufferAppend:
if tool.Function.Parameters != nil { // count audio token
countStr += fmt.Sprintf("%v", tool.Function.Parameters) atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
if err != nil {
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
}
audioToken += atk
case dto.RealtimeEventTypeResponseDone:
// count tools token
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8
textToken += toolTokens
} }
} }
toolTokens, err := CountTokenInput(countStr, model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
} }
} }
tkm *= ratio return textToken, audioToken, nil
return tkm, nil
} }
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
func CountTokenInput(input any, model string) (int, error) { func CountTokenInput(input any, model string) (int, error) {
switch v := input.(type) { switch v := input.(type) {
case string: case string:
return CountTokenText(v, model) return CountTextToken(v, model)
case []string: case []string:
text := "" text := ""
for _, s := range v { for _, s := range v {
text += s text += s
} }
return CountTokenText(text, model) return CountTextToken(text, model)
} }
return CountTokenInput(fmt.Sprintf("%v", input), model) return CountTokenInput(fmt.Sprintf("%v", input), model)
} }
@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens return tokens
} }
func CountAudioToken(text string, model string) (int, error) { func CountTTSToken(text string, model string) (int, error) {
if strings.HasPrefix(model, "tts") { if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil return utf8.RuneCountInString(text), nil
} else { } else {
return CountTokenText(text, model) return CountTextToken(text, model)
} }
} }
// CountTokenText 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量 func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
func CountTokenText(text string, model string) (int, error) { if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 100 / 0.06), nil
}
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
if audioBase64 == "" {
return 0, nil
}
duration, err := parseAudio(audioBase64, audioFormat)
if err != nil {
return 0, err
}
return int(duration / 60 * 200 / 0.24), nil
}
//func CountAudioToken(sec float64, audioType string) {
// if audioType == "input" {
//
// }
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error var err error
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err return getTokenNum(tokenEncoder, text), err

View File

@ -19,7 +19,7 @@ import (
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = promptTokens usage.PromptTokens = promptTokens
ctkm, err := CountTokenText(responseText, modeName) ctkm, err := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err return usage, err