feat: improve function call billing

CaIon 2024-04-23 23:01:06 +08:00
parent 89ebd85503
commit 2841669246
8 changed files with 65 additions and 8 deletions


@@ -32,6 +32,17 @@ type GeneralOpenAIRequest struct {
 	TopLogProbs int `json:"top_logprobs,omitempty"`
 }
 
+type OpenAITools struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+type OpenAIFunction struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil
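These two DTO types mirror the OpenAI `tools` array so the token counter can give the loosely typed `request.Tools` field a concrete shape via a JSON round-trip (see `CountTokenChatRequest` further down). A minimal standalone sketch of that round-trip; the package layout and sample payload are illustrative, not from the commit:

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of the dto types added in this commit.
type OpenAIFunction struct {
	Description string `json:"description,omitempty"`
	Name        string `json:"name"`
	Parameters  any    `json:"parameters,omitempty"`
}

type OpenAITools struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}

func main() {
	// Illustrative OpenAI-style tools payload, as it would arrive on a request.
	raw := `[{"type":"function","function":{"name":"get_weather","description":"Look up the weather","parameters":{"type":"object"}}}]`

	var tools []OpenAITools
	if err := json.Unmarshal([]byte(raw), &tools); err != nil {
		panic(err)
	}
	fmt.Println(tools[0].Function.Name) // get_weather
}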


@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {


@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
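For streamed responses the usage is reconstructed from the emitted text, so tool calls would otherwise bill almost nothing; the hunk above adds a flat surcharge of 7 completion tokens per counted tool call on top of the tokenized stream text. A sketch of the resulting arithmetic in isolation; the function and its tokenizer parameter are illustrative, and the constant 7 is this commit's heuristic, not a published OpenAI figure:

// estimateStreamCompletionTokens mirrors the billing above: tokenize the
// concatenated stream text (tool names and arguments are appended to it by
// OpenaiStreamHandler), then add a flat 7 tokens per detected tool call.
func estimateStreamCompletionTokens(countTokens func(string) int, responseText string, toolCount int) int {
	return countTokens(responseText) + toolCount*7
}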


@@ -16,9 +16,10 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -69,6 +70,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 			for _, choice := range streamResponse.Choices {
 				responseTextBuilder.WriteString(choice.Delta.Content)
 				if choice.Delta.ToolCalls != nil {
+					if len(choice.Delta.ToolCalls) > toolCount {
+						toolCount = len(choice.Delta.ToolCalls)
+					}
 					for _, tool := range choice.Delta.ToolCalls {
 						responseTextBuilder.WriteString(tool.Function.Name)
 						responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -82,6 +86,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 			for _, choice := range streamResponse.Choices {
 				responseTextBuilder.WriteString(choice.Delta.Content)
 				if choice.Delta.ToolCalls != nil {
+					if len(choice.Delta.ToolCalls) > toolCount {
+						toolCount = len(choice.Delta.ToolCalls)
+					}
 					for _, tool := range choice.Delta.ToolCalls {
 						responseTextBuilder.WriteString(tool.Function.Name)
 						responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -135,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String()
+	return nil, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
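Note that `toolCount` above is the largest number of tool-call deltas seen in any single chunk. OpenAI-style streams typically deliver one tool-call fragment per chunk, tagged with its position via an `index` field, so parallel calls spread across chunks may still count as 1. A hedged alternative sketch that counts distinct calls by index (assuming the delta DTO exposes such a field, which this commit does not use):

// distinctToolCalls counts distinct tool calls in a stream by tracking the
// highest tool-call index observed; with OpenAI-style streaming each fragment
// carries its position, so the distinct count is max(index)+1 even when every
// chunk's delta slice has length 1.
type distinctToolCalls struct{ max int }

func (d *distinctToolCalls) observe(index int) {
	if index+1 > d.max {
		d.max = index + 1
	}
}

func (d *distinctToolCalls) count() int { return d.max }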


@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)


@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}


@@ -189,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeCompletions:
 		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeModerations:


@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
+func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
+	tkm := 0
+	msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
+	if err != nil {
+		return 0, err, b
+	}
+	tkm += msgTokens
+	if request.Tools != nil {
+		toolsData, _ := json.Marshal(request.Tools)
+		var openaiTools []dto.OpenAITools
+		err := json.Unmarshal(toolsData, &openaiTools)
+		if err != nil {
+			return 0, errors.New(fmt.Sprintf("count tools token fail: %s", err.Error())), false
+		}
+		countStr := ""
+		for _, tool := range openaiTools {
+			countStr = tool.Function.Name
+			if tool.Function.Description != "" {
+				countStr += tool.Function.Description
+			}
+			if tool.Function.Parameters != nil {
+				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+			}
+		}
+		toolTokens, err, _ := CountTokenInput(countStr, model, false)
+		if err != nil {
+			return 0, err, false
+		}
+		tkm += 8
+		tkm += toolTokens
+	}
+	return tkm, nil, false
+}
+
 func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
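Two details of `CountTokenChatRequest` are worth flagging. The flat `tkm += 8` is a fixed overhead for the function-calling wrapper (the commit's heuristic). And the loop assigns `countStr = tool.Function.Name` on every iteration, so with multiple tools only the last tool's name, description, and parameters reach `CountTokenInput`. A standalone sketch of accumulation that counts every tool; the package and type names are local stand-ins for the dto declarations above:

package tokencount

import (
	"fmt"
	"strings"
)

// Local stand-ins for dto.OpenAITools / dto.OpenAIFunction.
type openAIFunction struct {
	Description string
	Name        string
	Parameters  any
}

type openAITool struct {
	Function openAIFunction
}

// toolsCountString appends every tool's name, description, and parameters,
// instead of resetting the accumulator each iteration as the committed loop does.
func toolsCountString(tools []openAITool) string {
	var b strings.Builder
	for _, tool := range tools {
		b.WriteString(tool.Function.Name)
		b.WriteString(tool.Function.Description) // writing "" is a no-op
		if tool.Function.Parameters != nil {
			fmt.Fprintf(&b, "%v", tool.Function.Parameters)
		}
	}
	return b.String()
}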