mirror of https://github.com/linux-do/new-api.git
feat: improve function-call billing
commit 2841669246
parent 89ebd85503
@@ -32,6 +32,17 @@ type GeneralOpenAIRequest struct {
 	TopLogProbs int `json:"top_logprobs,omitempty"`
 }
 
+type OpenAITools struct {
+	Type     string         `json:"type"`
+	Function OpenAIFunction `json:"function"`
+}
+
+type OpenAIFunction struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"`
+}
+
 func (r GeneralOpenAIRequest) ParseInput() []string {
 	if r.Input == nil {
 		return nil
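Note: the two DTOs added above model the OpenAI function-calling schema. A minimal round-trip sketch using copies of these types (the sample payload is illustrative, not from the commit):

package main

import (
	"encoding/json"
	"fmt"
)

// Copies of the DTOs introduced in this commit.
type OpenAIFunction struct {
	Description string `json:"description,omitempty"`
	Name        string `json:"name"`
	Parameters  any    `json:"parameters,omitempty"`
}

type OpenAITools struct {
	Type     string         `json:"type"`
	Function OpenAIFunction `json:"function"`
}

func main() {
	// Illustrative tools payload in the OpenAI chat-completions shape.
	payload := `[{"type":"function","function":{"name":"get_weather","description":"Look up weather","parameters":{"type":"object"}}}]`
	var tools []OpenAITools
	if err := json.Unmarshal([]byte(payload), &tools); err != nil {
		panic(err)
	}
	fmt.Println(tools[0].Function.Name) // get_weather
}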
@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if info.RelayMode == relayconstant.RelayModeEmbeddings {
@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
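Note: streamed responses containing tool calls are now billed a flat 7 completion tokens per counted tool call, on top of the usage derived from the response text. A minimal sketch of that adjustment in isolation (the Usage struct is a stand-in for dto.Usage; the 7-token constant is the commit's own heuristic):

package main

import "fmt"

// Usage is a stand-in for the CompletionTokens field of dto.Usage.
type Usage struct {
	CompletionTokens int
}

func main() {
	u := Usage{CompletionTokens: 120} // tokens counted from the streamed text
	toolCount := 2                    // tool calls counted by the stream handler
	u.CompletionTokens += toolCount * 7
	fmt.Println(u.CompletionTokens) // 134
}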
@@ -16,9 +16,10 @@ import (
 	"time"
 )
 
-func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
+func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
 	//checkSensitive := constant.ShouldCheckCompletionSensitive()
 	var responseTextBuilder strings.Builder
+	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -69,6 +70,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 			for _, choice := range streamResponse.Choices {
 				responseTextBuilder.WriteString(choice.Delta.Content)
 				if choice.Delta.ToolCalls != nil {
+					if len(choice.Delta.ToolCalls) > toolCount {
+						toolCount = len(choice.Delta.ToolCalls)
+					}
 					for _, tool := range choice.Delta.ToolCalls {
 						responseTextBuilder.WriteString(tool.Function.Name)
 						responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -82,6 +86,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 			for _, choice := range streamResponse.Choices {
 				responseTextBuilder.WriteString(choice.Delta.Content)
 				if choice.Delta.ToolCalls != nil {
+					if len(choice.Delta.ToolCalls) > toolCount {
+						toolCount = len(choice.Delta.ToolCalls)
+					}
 					for _, tool := range choice.Delta.ToolCalls {
 						responseTextBuilder.WriteString(tool.Function.Name)
 						responseTextBuilder.WriteString(tool.Function.Arguments)
@@ -135,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
 	}
 	wg.Wait()
-	return nil, responseTextBuilder.String()
+	return nil, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
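Note: in the stream loop above, toolCount ends up as the largest number of tool calls observed in any single delta, not a running total across chunks. A small sketch of that max-tracking rule over illustrative per-delta counts:

package main

import "fmt"

func main() {
	// Number of tool calls present in each streamed delta (illustrative).
	deltas := []int{1, 3, 2}
	toolCount := 0
	for _, n := range deltas {
		if n > toolCount {
			toolCount = n // keep the maximum, as in the handler above
		}
	}
	fmt.Println(toolCount) // 3
}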
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
+		var toolCount int
+		err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
 	} else {
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
@@ -189,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 	checkSensitive := constant.ShouldCheckPromptSensitive()
 	switch info.RelayMode {
 	case relayconstant.RelayModeChatCompletions:
-		promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
+		promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeCompletions:
 		promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
 	case relayconstant.RelayModeModerations:
@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
 	return tiles*170 + 85, nil
 }
 
+func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
+	tkm := 0
+	msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
+	if err != nil {
+		return 0, err, b
+	}
+	tkm += msgTokens
+	if request.Tools != nil {
+		toolsData, _ := json.Marshal(request.Tools)
+		var openaiTools []dto.OpenAITools
+		err := json.Unmarshal(toolsData, &openaiTools)
+		if err != nil {
+			return 0, errors.New(fmt.Sprintf("count tools token fail: %s", err.Error())), false
+		}
+		countStr := ""
+		for _, tool := range openaiTools {
+			countStr = tool.Function.Name
+			if tool.Function.Description != "" {
+				countStr += tool.Function.Description
+			}
+			if tool.Function.Parameters != nil {
+				countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+			}
+		}
+		toolTokens, err, _ := CountTokenInput(countStr, model, false)
+		if err != nil {
+			return 0, err, false
+		}
+		tkm += 8
+		tkm += toolTokens
+	}
+
+	return tkm, nil, false
+}
+
 func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
 	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
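Note: the new CountTokenChatRequest adds a fixed 8-token overhead when tools are present and tokenizes a concatenation of each tool's name, description, and parameters. A self-contained sketch of that countStr loop with illustrative tool definitions; as written, the plain assignment countStr = tool.Function.Name resets the accumulator on every iteration, so with multiple tools only the final tool's text reaches the tokenizer:

package main

import "fmt"

// fn stands in for dto.OpenAIFunction; the tool definitions are illustrative.
type fn struct {
	Name        string
	Description string
	Parameters  any
}

func main() {
	tools := []fn{
		{Name: "get_weather", Description: "Look up weather"},
		{Name: "get_time"},
	}
	countStr := ""
	for _, f := range tools {
		countStr = f.Name // plain assignment: resets the accumulator each tool
		if f.Description != "" {
			countStr += f.Description
		}
		if f.Parameters != nil {
			countStr += fmt.Sprintf("%v", f.Parameters)
		}
	}
	fmt.Println(countStr) // "get_time" — only the last tool's text is counted
}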