diff --git a/common/model-ratio.go b/common/model-ratio.go
index 664cd1d..b45ed29 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -42,6 +42,10 @@ var defaultModelRatio = map[string]float64{
 	"gpt-4o":                    2.5,  // $0.01 / 1K tokens
 	"gpt-4o-2024-05-13":         2.5,  // $0.01 / 1K tokens
 	"gpt-4o-2024-08-06":         1.25, // $0.01 / 1K tokens
+	"o1-preview":                7.5,
+	"o1-preview-2024-09-12":     7.5,
+	"o1-mini":                   1.5,
+	"o1-mini-2024-09-12":        1.5,
 	"gpt-4o-mini":               0.075,
 	"gpt-4o-mini-2024-07-18":    0.075,
 	"gpt-4-turbo":               5, // $0.01 / 1K tokens
@@ -106,8 +110,10 @@ var defaultModelRatio = map[string]float64{
 	"gemini-pro-vision":         1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"gemini-1.0-pro-vision-001": 1,
 	"gemini-1.0-pro-001":        1,
-	"gemini-1.5-pro-latest":     1,
+	"gemini-1.5-pro-latest":     1.75, // $3.5 / 1M tokens
+	"gemini-1.5-pro-exp-0827":   1.75, // $3.5 / 1M tokens
 	"gemini-1.5-flash-latest":   1,
+	"gemini-1.5-flash-exp-0827": 1,
 	"gemini-1.0-pro-latest":        1,
 	"gemini-1.0-pro-vision-latest": 1,
 	"gemini-ultra":                 1,
@@ -329,17 +335,6 @@ func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gpt-4o-gizmo") {
 		name = "gpt-4o-gizmo-*"
 	}
-	if strings.HasPrefix(name, "gpt-3.5") {
-		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
-			// https://openai.com/blog/new-embedding-models-and-api-updates
-			// Updated GPT-3.5 Turbo model and lower pricing
-			return 3
-		}
-		if strings.HasSuffix(name, "1106") {
-			return 2
-		}
-		return 4.0 / 3.0
-	}
 	if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
 		if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
 			return 3
@@ -352,6 +347,9 @@ func GetCompletionRatio(name string) float64 {
 		}
 		return 2
 	}
+	if strings.HasPrefix(name, "o1-") {
+		return 4
+	}
 	if name == "chatgpt-4o-latest" {
 		return 3
 	}
@@ -362,6 +360,17 @@ func GetCompletionRatio(name string) float64 {
 	} else if strings.Contains(name, "claude-3") {
 		return 5
 	}
+	if strings.HasPrefix(name, "gpt-3.5") {
+		if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
+			// https://openai.com/blog/new-embedding-models-and-api-updates
+			// Updated GPT-3.5 Turbo model and lower pricing
+			return 3
+		}
+		if strings.HasSuffix(name, "1106") {
+			return 2
+		}
+		return 4.0 / 3.0
+	}
 	if strings.HasPrefix(name, "mistral-") {
 		return 3
 	}
diff --git a/constant/env.go b/constant/env.go
index fc9fd8b..6bab4be 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
 var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 
 var GeminiModelMap = map[string]string{
-	"gemini-1.5-pro-latest":   "v1beta",
-	"gemini-1.5-pro-001":      "v1beta",
-	"gemini-1.5-pro":          "v1beta",
-	"gemini-1.5-pro-exp-0801": "v1beta",
-	"gemini-1.5-flash-latest": "v1beta",
-	"gemini-1.5-flash-001":    "v1beta",
-	"gemini-1.5-flash":        "v1beta",
-	"gemini-ultra":            "v1beta",
+	"gemini-1.5-pro-latest":     "v1beta",
+	"gemini-1.5-pro-001":        "v1beta",
+	"gemini-1.5-pro":            "v1beta",
+	"gemini-1.5-pro-exp-0801":   "v1beta",
+	"gemini-1.5-pro-exp-0827":   "v1beta",
+	"gemini-1.5-flash-latest":   "v1beta",
+	"gemini-1.5-flash-exp-0827": "v1beta",
+	"gemini-1.5-flash-001":      "v1beta",
+	"gemini-1.5-flash":          "v1beta",
+	"gemini-ultra":              "v1beta",
 }
 
 func InitEnv() {
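The o1 entries above follow this file's convention that one ratio unit equals $0.002 per 1K prompt tokens, so 7.5 works out to $15 / 1M (o1-preview) and 1.5 to $3 / 1M (o1-mini); the new `o1-` branch in `GetCompletionRatio` returns 4 because OpenAI bills o1 output at four times the input rate. Note that the `gpt-3.5` block is relocated below the Claude checks, not deleted, and the `GeminiModelMap` additions simply pin the two 0827 experimental models to the `v1beta` API path. A rough sketch of how the two ratios combine; `GetModelRatio` is assumed to be the existing lookup in `model-ratio.go`, and the $0.002 constant is the convention implied by the table, not a value defined there:

```go
// Sketch: estimated USD cost of one call, combining both ratio tables.
// Assumes GetModelRatio/GetCompletionRatio from common/model-ratio.go.
func estimateCostUSD(model string, promptTokens, completionTokens int) float64 {
	usdPer1K := 0.002 * GetModelRatio(model) // 0.015 for o1-preview
	weighted := float64(promptTokens) +
		float64(completionTokens)*GetCompletionRatio(model) // o1-*: 4x output
	return usdPer1K * weighted / 1000
}
```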
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 95c4a60..ff66386 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -20,6 +20,7 @@ import (
 	"one-api/relay/constant"
 	"one-api/service"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
 
-	request := buildTestRequest()
-	request.Model = testModel
+	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	return nil, nil
 }
 
-func buildTestRequest() *dto.GeneralOpenAIRequest {
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 	testRequest := &dto.GeneralOpenAIRequest{
-		Model:     "", // this will be set later
-		MaxTokens: 1,
-		Stream:    false,
+		Model:  "", // this will be set later
+		Stream: false,
+	}
+	if strings.HasPrefix(model, "o1-") {
+		testRequest.MaxCompletionTokens = 1
+	} else {
+		testRequest.MaxTokens = 1
 	}
 	content, _ := json.Marshal("hi")
 	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
+	testRequest.Model = model
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
 }
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
 				tok := time.Now()
 				milliseconds := tok.Sub(tik).Milliseconds()
-				ban := false
-				if milliseconds > disableThreshold {
-					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-					ban = true
-				}
+				shouldBanChannel := false
 				// request error disables the channel
 				if openaiWithStatusErr != nil {
 					oaiErr := openaiWithStatusErr.Error
 					err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
-					ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
+					shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 				}
-				// parse *int to bool
-				if !channel.GetAutoBan() {
-					ban = false
+				if milliseconds > disableThreshold {
+					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+					shouldBanChannel = true
 				}
 				// disable channel
-				if ban && isChannelEnabled {
+				if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
 					service.DisableChannel(channel.Id, channel.Name, err.Error())
 				}
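`buildTestRequest` now takes the model up front because o1 models reject `max_tokens` and expect `max_completion_tokens`, so the probe has to choose the field before serialization. In `testAllChannels`, consulting `channel.GetAutoBan()` at the point of decision (instead of clearing the flag) reads more clearly, and moving the latency check after the error check means that when a request both fails and exceeds the threshold, the recorded reason is now the timeout. A sketch of the two probe bodies this produces; the JSON in the comments is illustrative and assumes the usual `omitempty` tags on `dto.Message`:

```go
// Sketch: the two request shapes buildTestRequest produces after this change.
probe := buildTestRequest("o1-mini")
body, _ := json.Marshal(probe)
fmt.Println(string(body))
// {"model":"o1-mini","messages":[{"role":"user","content":"hi"}],"max_completion_tokens":1}

probe = buildTestRequest("gpt-4o-mini")
body, _ = json.Marshal(probe)
fmt.Println(string(body))
// {"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"max_tokens":1}
```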
diff --git a/dto/text_request.go b/dto/text_request.go
index a804e63..0fdfc03 100644
--- a/dto/text_request.go
+++ b/dto/text_request.go
@@ -7,31 +7,32 @@ type ResponseFormat struct {
 }
 
 type GeneralOpenAIRequest struct {
-	Model            string         `json:"model,omitempty"`
-	Messages         []Message      `json:"messages,omitempty"`
-	Prompt           any            `json:"prompt,omitempty"`
-	Stream           bool           `json:"stream,omitempty"`
-	StreamOptions    *StreamOptions `json:"stream_options,omitempty"`
-	MaxTokens        uint           `json:"max_tokens,omitempty"`
-	Temperature      float64        `json:"temperature,omitempty"`
-	TopP             float64        `json:"top_p,omitempty"`
-	TopK             int            `json:"top_k,omitempty"`
-	Stop             any            `json:"stop,omitempty"`
-	N                int            `json:"n,omitempty"`
-	Input            any            `json:"input,omitempty"`
-	Instruction      string         `json:"instruction,omitempty"`
-	Size             string         `json:"size,omitempty"`
-	Functions        any            `json:"functions,omitempty"`
-	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
-	ResponseFormat   any            `json:"response_format,omitempty"`
-	Seed             float64        `json:"seed,omitempty"`
-	Tools            []ToolCall     `json:"tools,omitempty"`
-	ToolChoice       any            `json:"tool_choice,omitempty"`
-	User             string         `json:"user,omitempty"`
-	LogProbs         bool           `json:"logprobs,omitempty"`
-	TopLogProbs      int            `json:"top_logprobs,omitempty"`
-	Dimensions       int            `json:"dimensions,omitempty"`
+	Model               string         `json:"model,omitempty"`
+	Messages            []Message      `json:"messages,omitempty"`
+	Prompt              any            `json:"prompt,omitempty"`
+	Stream              bool           `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions `json:"stream_options,omitempty"`
+	MaxTokens           uint           `json:"max_tokens,omitempty"`
+	MaxCompletionTokens uint           `json:"max_completion_tokens,omitempty"`
+	Temperature         float64        `json:"temperature,omitempty"`
+	TopP                float64        `json:"top_p,omitempty"`
+	TopK                int            `json:"top_k,omitempty"`
+	Stop                any            `json:"stop,omitempty"`
+	N                   int            `json:"n,omitempty"`
+	Input               any            `json:"input,omitempty"`
+	Instruction         string         `json:"instruction,omitempty"`
+	Size                string         `json:"size,omitempty"`
+	Functions           any            `json:"functions,omitempty"`
+	FrequencyPenalty    float64        `json:"frequency_penalty,omitempty"`
+	PresencePenalty     float64        `json:"presence_penalty,omitempty"`
+	ResponseFormat      any            `json:"response_format,omitempty"`
+	Seed                float64        `json:"seed,omitempty"`
+	Tools               []ToolCall     `json:"tools,omitempty"`
+	ToolChoice          any            `json:"tool_choice,omitempty"`
+	User                string         `json:"user,omitempty"`
+	LogProbs            bool           `json:"logprobs,omitempty"`
+	TopLogProbs         int            `json:"top_logprobs,omitempty"`
+	Dimensions          int            `json:"dimensions,omitempty"`
 }
 
 type OpenAITools struct {
diff --git a/dto/text_response.go b/dto/text_response.go
index 9b12683..5d13773 100644
--- a/dto/text_response.go
+++ b/dto/text_response.go
@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
 
 type OpenAITextResponse struct {
 	Id      string                     `json:"id"`
+	Model   string                     `json:"model"`
 	Object  string                     `json:"object"`
 	Created int64                      `json:"created"`
 	Choices []OpenAITextResponseChoice `json:"choices"`
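`MaxCompletionTokens` mirrors `MaxTokens`: both are `uint` with `omitempty`, so whichever is zero disappears from the marshaled body entirely. The o1 remap in `relay/channel/openai/adaptor.go` further down relies on exactly that, moving a legacy `max_tokens` value over and zeroing the original. The new `Model` field on `OpenAITextResponse` lets converted responses (see the Claude change below) echo the upstream model name. A minimal sketch, assuming the `dto` package as declared above:

```go
// Sketch: only the populated token field reaches the wire.
req := dto.GeneralOpenAIRequest{
	Model:               "o1-preview",
	MaxCompletionTokens: 256,
}
b, _ := json.Marshal(req)
fmt.Println(string(b)) // {"model":"o1-preview","max_completion_tokens":256}
```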
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 1923e35..781b9a7 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		}
 	}
 	formatMessages := make([]dto.Message, 0)
-	var lastMessage *dto.Message
+	lastMessage := dto.Message{
+		Role: "tool",
+	}
 	for i, message := range textRequest.Messages {
-		//if message.Role == "system" {
-		//	if i != 0 {
-		//		message.Role = "user"
-		//	}
-		//}
 		if message.Role == "" {
 			textRequest.Messages[i].Role = "user"
 		}
 		fmtMessage := dto.Message{
 			Role:    message.Role,
 			Content: message.Content,
 		}
-		if lastMessage != nil && lastMessage.Role == message.Role {
+		if message.Role == "tool" {
+			fmtMessage.ToolCallId = message.ToolCallId
+		}
+		if message.Role == "assistant" && message.ToolCalls != nil {
+			fmtMessage.ToolCalls = message.ToolCalls
+		}
+		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 			if lastMessage.IsStringContent() && message.IsStringContent() {
 				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 				fmtMessage.Content = content
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			fmtMessage.Content = content
 		}
 		formatMessages = append(formatMessages, fmtMessage)
-		lastMessage = &textRequest.Messages[i]
+		lastMessage = fmtMessage
 	}
 
 	claudeMessages := make([]ClaudeMessage, 0)
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		claudeMessage := ClaudeMessage{
 			Role: message.Role,
 		}
-		if message.IsStringContent() {
+		if message.Role == "tool" {
+			if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
+				lastMessage := claudeMessages[len(claudeMessages)-1]
+				if content, ok := lastMessage.Content.(string); ok {
+					lastMessage.Content = []ClaudeMediaMessage{
+						{
+							Type: "text",
+							Text: content,
+						},
+					}
+				}
+				lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
+					Type:      "tool_result",
+					ToolUseId: message.ToolCallId,
+					Content:   message.StringContent(),
+				})
+				claudeMessages[len(claudeMessages)-1] = lastMessage
+				continue
+			} else {
+				claudeMessage.Role = "user"
+				claudeMessage.Content = []ClaudeMediaMessage{
+					{
+						Type:      "tool_result",
+						ToolUseId: message.ToolCallId,
+						Content:   message.StringContent(),
+					},
+				}
+			}
+		} else if message.IsStringContent() && message.ToolCalls == nil {
 			claudeMessage.Content = message.StringContent()
 		} else {
 			claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 			}
 			claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 		}
+			if message.ToolCalls != nil {
+				for _, tc := range message.ToolCalls.([]interface{}) {
+					toolCallJSON, _ := json.Marshal(tc)
+					var toolCall dto.ToolCall
+					err := json.Unmarshal(toolCallJSON, &toolCall)
+					if err != nil {
+						common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
+						continue
+					}
+					inputObj := make(map[string]any)
+					if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
+						common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+						continue
+					}
+					claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
+						Type:  "tool_use",
+						Id:    toolCall.ID,
+						Name:  toolCall.Function.Name,
+						Input: inputObj,
+					})
+				}
+			}
 			claudeMessage.Content = claudeMediaMessages
 		}
 		claudeMessages = append(claudeMessages, claudeMessage)
@@ -341,6 +395,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	if len(tools) > 0 {
 		choice.Message.ToolCalls = tools
 	}
+	fullTextResponse.Model = claudeResponse.Model
 	choices = append(choices, choice)
 	fullTextResponse.Choices = choices
 	return &fullTextResponse
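The Claude converter now round-trips OpenAI tool traffic: assistant `tool_calls` become `tool_use` content blocks (with the JSON-string arguments decoded into an object), and `tool`-role messages become `tool_result` blocks, which Anthropic requires inside a user turn, hence the merge into the preceding user message when one exists. Seeding `lastMessage` with role `tool` and adding the `lastMessage.Role != "tool"` guard keeps consecutive tool results from being concatenated the way ordinary same-role messages are. A self-contained sketch of the core argument conversion; the names `call_1` and `get_weather` are illustrative:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// OpenAI carries function arguments as a JSON string; Claude wants an object.
// Mirrors the tool_use branch above, minus the dto types.
func main() {
	arguments := `{"city":"Paris"}`
	input := map[string]any{}
	if err := json.Unmarshal([]byte(arguments), &input); err != nil {
		panic(err) // the real code logs via common.SysError and skips the call
	}
	block := map[string]any{
		"type":  "tool_use",
		"id":    "call_1",
		"name":  "get_weather",
		"input": input,
	}
	out, _ := json.Marshal(block)
	fmt.Println(string(out))
	// {"id":"call_1","input":{"city":"Paris"},"name":"get_weather","type":"tool_use"}
}
```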
diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go
index 7f50a15..e7452fd 100644
--- a/relay/channel/cohere/dto.go
+++ b/relay/channel/cohere/dto.go
@@ -8,7 +8,7 @@ type CohereRequest struct {
 	Message    string `json:"message"`
 	Stream     bool   `json:"stream"`
 	MaxTokens  int    `json:"max_tokens"`
-	SafetyMode string `json:"safety_mode"`
+	SafetyMode string `json:"safety_mode,omitempty"`
 }
 
 type ChatHistory struct {
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index adec316..132039b 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -22,7 +22,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
 		Message:    "",
 		Stream:     textRequest.Stream,
 		MaxTokens:  textRequest.GetMaxTokens(),
-		SafetyMode: common.CohereSafetySetting,
+	}
+	if common.CohereSafetySetting != "NONE" {
+		cohereReq.SafetyMode = common.CohereSafetySetting
 	}
 	if cohereReq.MaxTokens == 0 {
 		cohereReq.MaxTokens = 4000
diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go
index 621336b..4a2e4dd 100644
--- a/relay/channel/gemini/constant.go
+++ b/relay/channel/gemini/constant.go
@@ -6,7 +6,7 @@ const (
 
 var ModelList = []string{
 	"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
-	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
+	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
 }
 
 var ChannelName = "google gemini"
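Gating on the `NONE` sentinel and adding `omitempty` together mean `safety_mode` only reaches the wire when a real value is configured, so default deployments keep sending the request shape older Cohere endpoints accept. (The Gemini `ModelList` addition above just registers the two 0827 experimental models alongside their `GeminiModelMap` and ratio entries.) The serialization mechanics as a tiny standalone demo:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type demo struct {
	SafetyMode string `json:"safety_mode,omitempty"`
}

func main() {
	b, _ := json.Marshal(demo{}) // field left empty, as when the setting is NONE
	fmt.Println(string(b))       // prints {} and the key is dropped entirely
}
```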
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index fac6b7f..ed4d5ca 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -17,11 +17,25 @@ type OllamaRequest struct {
 	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
 }
 
+type Options struct {
+	Seed             int     `json:"seed,omitempty"`
+	Temperature      float64 `json:"temperature,omitempty"`
+	TopK             int     `json:"top_k,omitempty"`
+	TopP             float64 `json:"top_p,omitempty"`
+	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int     `json:"num_predict,omitempty"`
+	NumCtx           int     `json:"num_ctx,omitempty"`
+}
+
 type OllamaEmbeddingRequest struct {
-	Model  string `json:"model,omitempty"`
-	Prompt any    `json:"prompt,omitempty"`
+	Model   string   `json:"model,omitempty"`
+	Input   []string `json:"input"`
+	Options *Options `json:"options,omitempty"`
 }
 
 type OllamaEmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"`
+	Model     string    `json:"model"`
 	Embedding []float64 `json:"embedding,omitempty"`
 }
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 6bf395a..b2d4630 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"one-api/dto"
 	"one-api/service"
-	"strings"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 
 func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
 	return &OllamaEmbeddingRequest{
-		Model:  request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Model: request.Model,
+		Input: request.ParseInput(),
+		Options: &Options{
+			Seed:             int(request.Seed),
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
 
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if ollamaEmbeddingResponse.Error != "" {
+		return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
+	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
 		Embedding: ollamaEmbeddingResponse.Embedding,
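`requestOpenAI2Embeddings` now emits Ollama's newer embedding shape, `input` as a string array instead of a single space-joined `prompt`, and forwards sampling parameters through the new `Options` struct (`request.Seed` is a `float64` in the dto, hence the cast). One caveat in `ollamaEmbeddingHandler`: the new branch wraps `err`, which is necessarily nil once `json.Unmarshal` has succeeded, so Ollama's own message in `Error` never reaches the client. A hedged fix sketch, assuming the file imports `errors` (add it otherwise):

```go
// Sketch: surface Ollama's error text instead of wrapping a nil err.
if ollamaEmbeddingResponse.Error != "" {
	return service.OpenAIErrorWrapper(errors.New(ollamaEmbeddingResponse.Error),
		"ollama_error", resp.StatusCode), nil
}
```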
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 4388efd..8e4cf78 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -78,6 +78,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	if info.ChannelType != common.ChannelTypeOpenAI {
 		request.StreamOptions = nil
 	}
+	if strings.HasPrefix(request.Model, "o1-") {
+		if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+			request.MaxCompletionTokens = request.MaxTokens
+			request.MaxTokens = 0
+		}
+	}
 	return request, nil
 }
 
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
index ac2d673..5e5fd21 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/channel/openai/constant.go
@@ -11,6 +11,8 @@ var ModelList = []string{
 	"chatgpt-4o-latest",
 	"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
 	"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
+	"o1-preview", "o1-preview-2024-09-12",
+	"o1-mini", "o1-mini-2024-09-12",
 	"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
 	"text-curie-001", "text-babbage-001", "text-ada-001",
 	"text-moderation-latest", "text-moderation-stable",
diff --git a/web/src/components/Footer.js b/web/src/components/Footer.js
index 0e71b72..7b80ac7 100644
--- a/web/src/components/Footer.js
+++ b/web/src/components/Footer.js
@@ -59,12 +59,10 @@ const Footer = () => {
       {footer ? (
-
-
-
+
       ) : (
        defaultFooter
      )}