From 5b18cd6b0ad5c96978314e31e0e51bbef6097c18 Mon Sep 17 00:00:00 2001
From: "1808837298@qq.com" <1808837298@qq.com>
Date: Thu, 29 Feb 2024 01:08:18 +0800
Subject: [PATCH] feat: preliminary refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/utils.go                       |   2 +-
 controller/channel-billing.go         |   5 +-
 controller/channel-test.go            | 160 ++--
 controller/midjourney.go              |  10 +-
 controller/model.go                   |   1 -
 controller/relay-aiproxy.go           | 220 -----
 controller/relay-text.go              | 752 ------------------
 controller/relay.go                   | 374 +--------
 dto/error.go                          |  13 +
 dto/request.go                        | 137 ++++
 dto/response.go                       |  86 ++
 main.go                               |   3 +-
 middleware/distributor.go             |   7 +-
 relay/channel/adapter.go              |  57 ++
 relay/channel/ali/adaptor.go          |  80 ++
 relay/channel/ali/constants.go        |   8 +
 relay/channel/ali/dto.go              |  70 ++
 .../channel/ali}/relay-ali.go         | 141 +---
 relay/channel/api_request.go          |  52 ++
 relay/channel/baidu/adaptor.go        |  92 +++
 relay/channel/baidu/constants.go      |  12 +
 relay/channel/baidu/dto.go            |  71 ++
 .../channel/baidu}/relay-baidu.go     | 140 +---
 relay/channel/claude/adaptor.go       |  65 ++
 relay/channel/claude/constants.go     |   7 +
 relay/channel/claude/dto.go           |  29 +
 .../channel/claude}/relay-claude.go   |  76 +-
 relay/channel/gemini/adaptor.go       |  64 ++
 relay/channel/gemini/constant.go      |  12 +
 relay/channel/gemini/dto.go           |  62 ++
 .../channel/gemini}/relay-gemini.go   | 130 +--
 relay/channel/moonshot/constants.go   |   7 +
 relay/channel/openai/adaptor.go       |  84 ++
 relay/channel/openai/constant.go      |  21 +
 .../channel/openai}/relay-openai.go   |  39 +-
 relay/channel/palm/adaptor.go         |  59 ++
 relay/channel/palm/constants.go       |   7 +
 relay/channel/palm/dto.go             |  38 +
 .../channel/palm}/relay-palm.go       |  86 +-
 relay/channel/tencent/adaptor.go      |  73 ++
 relay/channel/tencent/constants.go    |   9 +
 relay/channel/tencent/dto.go          |  61 ++
 .../channel/tencent}/relay-tencent.go | 101 +--
 relay/channel/xunfei/adaptor.go       |  68 ++
 relay/channel/xunfei/constants.go     |  11 +
 relay/channel/xunfei/dto.go           |  59 ++
 .../channel/xunfei}/relay-xunfei.go   | 103 +--
 relay/channel/zhipu/adaptor.go        |  61 ++
 relay/channel/zhipu/constants.go      |   7 +
 relay/channel/zhipu/dto.go            |  46 ++
 .../channel/zhipu}/relay-zhipu.go     |  97 +--
 relay/common/relay_info.go            |  71 ++
 relay/common/relay_utils.go           |  68 ++
 relay/constant/api_type.go            |  45 ++
 relay/constant/relay_mode.go          |  50 ++
 {controller => relay}/relay-audio.go  |  57 +-
 {controller => relay}/relay-image.go  |  13 +-
 {controller => relay}/relay-mj.go     |  19 +-
 relay/relay-text.go                   | 277 +++++++
 router/relay-router.go                |   6 +-
 service/channel.go                    |  53 ++
 service/error.go                      |  29 +
 service/http_client.go                |  32 +
 service/sse.go                        |  11 +
 .../token_counter.go                  | 139 +---
 service/usage_helpr.go                |  27 +
 service/user_notify.go                |  17 +
 67 files changed, 2646 insertions(+), 2243 deletions(-)
 delete mode 100644 controller/relay-aiproxy.go
 delete mode 100644 controller/relay-text.go
 create mode 100644 dto/error.go
 create mode 100644 dto/request.go
 create mode 100644 dto/response.go
 create mode 100644 relay/channel/adapter.go
 create mode 100644 relay/channel/ali/adaptor.go
 create mode 100644 relay/channel/ali/constants.go
 create mode 100644 relay/channel/ali/dto.go
 rename {controller => relay/channel/ali}/relay-ali.go (61%)
 create mode 100644 relay/channel/api_request.go
 create mode 100644 relay/channel/baidu/adaptor.go
 create mode 100644 relay/channel/baidu/constants.go
 create mode 100644 relay/channel/baidu/dto.go
 rename {controller => relay/channel/baidu}/relay-baidu.go (63%)
 create mode 100644 relay/channel/claude/adaptor.go
 create mode 100644 relay/channel/claude/constants.go
 create mode 100644 relay/channel/claude/dto.go
 rename {controller => relay/channel/claude}/relay-claude.go (68%)
 create mode 100644 relay/channel/gemini/adaptor.go
 create mode 100644 relay/channel/gemini/constant.go
 create mode 100644 relay/channel/gemini/dto.go
 rename {controller => relay/channel/gemini}/relay-gemini.go (62%)
 create mode 100644 relay/channel/moonshot/constants.go
 create mode 100644 relay/channel/openai/adaptor.go
 create mode 100644 relay/channel/openai/constant.go
 rename {controller => relay/channel/openai}/relay-openai.go (74%)
 create mode 100644 relay/channel/palm/adaptor.go
 create mode 100644 relay/channel/palm/constants.go
 create mode 100644 relay/channel/palm/dto.go
 rename {controller => relay/channel/palm}/relay-palm.go (63%)
 create mode 100644 relay/channel/tencent/adaptor.go
 create mode 100644 relay/channel/tencent/constants.go
 create mode 100644 relay/channel/tencent/dto.go
 rename {controller => relay/channel/tencent}/relay-tencent.go (57%)
 create mode 100644 relay/channel/xunfei/adaptor.go
 create mode 100644 relay/channel/xunfei/constants.go
 create mode 100644 relay/channel/xunfei/dto.go
 rename {controller => relay/channel/xunfei}/relay-xunfei.go (69%)
 create mode 100644 relay/channel/zhipu/adaptor.go
 create mode 100644 relay/channel/zhipu/constants.go
 create mode 100644 relay/channel/zhipu/dto.go
 rename {controller => relay/channel/zhipu}/relay-zhipu.go (67%)
 create mode 100644 relay/common/relay_info.go
 create mode 100644 relay/common/relay_utils.go
 create mode 100644 relay/constant/api_type.go
 create mode 100644 relay/constant/relay_mode.go
 rename {controller => relay}/relay-audio.go (67%)
 rename {controller => relay}/relay-image.go (96%)
 rename {controller => relay}/relay-mj.go (97%)
 create mode 100644 relay/relay-text.go
 create mode 100644 service/channel.go
 create mode 100644 service/error.go
 create mode 100644 service/http_client.go
 create mode 100644 service/sse.go
 rename controller/relay-utils.go => service/token_counter.go (61%)
 create mode 100644 service/usage_helpr.go
 create mode 100644 service/user_notify.go

diff --git a/common/utils.go b/common/utils.go
index 7a0149d..eb6678a 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -230,7 +230,7 @@ func StringsContains(strs []string, str string) bool {
 	return false
 }
 
-// []byte only read, panic on append
+// StringToByteSlice []byte only read, panic on append
 func StringToByteSlice(s string) []byte {
 	tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
 	tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 30c9cbc..1bbd8db 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	"one-api/service"
 	"strconv"
 	"time"
 
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
 	for k := range headers {
 		req.Header.Add(k, headers.Get(k))
 	}
-	res, err := httpClient.Do(req)
+	res, err := service.GetHttpClient().Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -310,7 +311,7 @@ func updateAllChannelsBalance() error {
 		} else {
 			// err is nil & balance <= 0 means quota is used up
 			if balance <= 0 {
-				disableChannel(channel.Id, channel.Name, "余额不足")
+				service.DisableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
 		time.Sleep(common.RequestInterval)
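The hub of this refactor is the new Adaptor abstraction in relay/channel/adapter.go (57 lines, not reproduced in this excerpt); the rewritten testChannel below and the new relay/relay-text.go both program against it. A minimal sketch, inferred from the call sites visible in this patch: the method names (Init, GetModelList, ConvertRequest, DoRequest, DoResponse) and the GetAdaptor entry point are confirmed by usage, but the exact signatures and whether RelayInfo is passed by pointer are assumptions.

    package channel

    import (
        "io"
        "net/http"
        "one-api/dto"
        relaycommon "one-api/relay/common"

        "github.com/gin-gonic/gin"
    )

    // Adaptor is the seam between the generic relay flow and one upstream vendor.
    // Each relay/channel/<vendor> package provides an implementation.
    type Adaptor interface {
        // Init captures per-request metadata before conversion begins.
        Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
        // ConvertRequest maps an OpenAI-style request to the vendor wire format.
        ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
        // DoRequest sends the converted body upstream and returns the raw response.
        DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
        // DoResponse translates the vendor response back to the OpenAI shape and
        // reports token usage for billing.
        DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode)
        // GetModelList lists the models the vendor serves (used as a test default).
        GetModelList() []string
    }

GetAdaptor(apiType) returns the implementation for an API type, or nil for an unknown one; callers such as testChannel treat nil as an error rather than panicking.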
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 3afd4b8..1ea767a 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -5,9 +5,17 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/model"
+	relaychannel "one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
+	"one-api/service"
 	"strconv"
 	"sync"
 	"time"
@@ -15,89 +23,77 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
-	switch channel.Type {
-	case common.ChannelTypePaLM:
-		fallthrough
-	case common.ChannelTypeAnthropic:
-		fallthrough
-	case common.ChannelTypeBaidu:
-		fallthrough
-	case common.ChannelTypeZhipu:
-		fallthrough
-	case common.ChannelTypeAli:
-		fallthrough
-	case common.ChannelType360:
-		fallthrough
-	case common.ChannelTypeGemini:
-		fallthrough
-	case common.ChannelTypeXunfei:
-		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
-	case common.ChannelTypeAzure:
-		if request.Model == "" {
-			request.Model = "gpt-35-turbo"
-		}
-		defer func() {
-			if err != nil {
-				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
-			}
-		}()
-	default:
-		if request.Model == "" {
-			request.Model = "gpt-3.5-turbo"
-		}
+func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = &http.Request{
+		Method: "POST",
+		URL:    &url.URL{Path: "/v1/chat/completions"},
+		Body:   nil,
+		Header: make(http.Header),
 	}
-	baseUrl := common.ChannelBaseURLs[channel.Type]
-	if channel.GetBaseURL() != "" {
-		baseUrl = channel.GetBaseURL()
+	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+	c.Request.Header.Set("Content-Type", "application/json")
+	c.Set("channel", channel.Type)
+	c.Set("base_url", channel.GetBaseURL())
+	meta := relaycommon.GenRelayInfo(c)
+	apiType := constant.ChannelType2APIType(channel.Type)
+	adaptor := relaychannel.GetAdaptor(apiType)
+	if adaptor == nil {
+		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 	}
-	requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
+	if testModel == "" {
+		testModel = adaptor.GetModelList()[0]
+	}
+	request := buildTestRequest()
 
-	if channel.Type == common.ChannelTypeAzure {
-		requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
-	}
+	adaptor.Init(meta, *request)
 
-	jsonData, err := json.Marshal(request)
+	request.Model = testModel
+	meta.UpstreamModelName = testModel
+	convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
 	if err != nil {
 		return err, nil
 	}
-	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
+	jsonData, err := json.Marshal(convertedRequest)
 	if err != nil {
 		return err, nil
 	}
-	if channel.Type == common.ChannelTypeAzure {
-		req.Header.Set("api-key", channel.Key)
-	} else {
-		req.Header.Set("Authorization", "Bearer "+channel.Key)
-	}
-	req.Header.Set("Content-Type", "application/json")
-	resp, err := httpClient.Do(req)
+	requestBody := bytes.NewBuffer(jsonData)
+	c.Request.Body = io.NopCloser(requestBody)
+	resp, err := adaptor.DoRequest(c, meta, requestBody)
 	if err != nil {
 		return err, nil
 	}
-	defer resp.Body.Close()
-	var response TextResponse
-	err = json.NewDecoder(resp.Body).Decode(&response)
+	if resp.StatusCode != http.StatusOK {
+		err := relaycommon.RelayErrorHandler(resp)
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
+	}
+	usage, respErr := adaptor.DoResponse(c, resp, meta)
+	if respErr != nil {
+		return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
+	}
+	if usage == nil {
+		return errors.New("usage is nil"), nil
+	}
+	result := w.Result()
+	// print result.Body
+	respBody, err := io.ReadAll(result.Body)
 	if err != nil {
 		return err, nil
 	}
-	if response.Usage.CompletionTokens == 0 {
-		if response.Error.Message == "" {
-			response.Error.Message = "补全 tokens 非预期返回 0"
-		}
-		return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
-	}
+	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 	return nil, nil
 }
 
-func buildTestRequest() *ChatRequest {
-	testRequest := &ChatRequest{
+func buildTestRequest() *dto.GeneralOpenAIRequest {
+	testRequest := &dto.GeneralOpenAIRequest{
 		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
 	content, _ := json.Marshal("hi")
-	testMessage := Message{
+	testMessage := dto.Message{
 		Role:    "user",
 		Content: content,
 	}
@@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
 		})
 		return
 	}
-	testModel := c.Query("model")
 	channel, err := model.GetChannelById(id, true)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
@@ -123,12 +118,9 @@
 		})
 		return
 	}
-	testRequest := buildTestRequest()
-	if testModel != "" {
-		testRequest.Model = testModel
-	}
+	testModel := c.Query("model")
 	tik := time.Now()
-	err, _ = testChannel(channel, *testRequest)
+	err, _ = testChannel(channel, testModel)
 	tok := time.Now()
 	milliseconds := tok.Sub(tik).Milliseconds()
 	go channel.UpdateResponseTime(milliseconds)
@@ -152,31 +144,6 @@
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
-}
-
-func enableChannel(channelId int, channelName string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
-}
-
-func notifyRootUser(subject string, content string) {
-	if common.RootUserEmail == "" {
-		common.RootUserEmail = model.GetRootUserEmail()
-	}
-	err := common.SendEmail(subject, common.RootUserEmail, content)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-	}
-}
-
 func testAllChannels(notify bool) error {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
@@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
 	if err != nil {
 		return err
 	}
-	testRequest := buildTestRequest()
 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 	if disableThreshold == 0 {
 		disableThreshold = 10000000 // a impossible value
 	}
 	for _, channel := range channels {
 		isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 		tik := time.Now()
-		err, openaiErr := testChannel(channel, *testRequest)
+		err, openaiErr := testChannel(channel, "")
 		tok := time.Now()
 		milliseconds := tok.Sub(tik).Milliseconds()
@@ -218,11 +184,11 @@
 		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 			ban = false
 		}
-		if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
-			disableChannel(channel.Id, channel.Name, err.Error())
+		if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
+			service.DisableChannel(channel.Id, channel.Name, err.Error())
 		}
-		if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
-			enableChannel(channel.Id, channel.Name)
+		if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
+			service.EnableChannel(channel.Id, channel.Name)
 		}
 		channel.UpdateResponseTime(milliseconds)
 		time.Sleep(common.RequestInterval)
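testChannel resolves a channel to an adaptor via constant.ChannelType2APIType from the new relay/constant/api_type.go (45 lines, also not reproduced here). Its mapping is presumably the channelType switch that the deleted controller/relay-text.go performed inline (see the removed block further down); a sketch under that assumption, reusing the constant names from the deleted file:

    package constant

    import "one-api/common"

    // API types, carried over from the iota block deleted from
    // controller/relay-text.go in this patch.
    const (
        APITypeOpenAI = iota
        APITypeClaude
        APITypePaLM
        APITypeBaidu
        APITypeZhipu
        APITypeAli
        APITypeXunfei
        APITypeAIProxyLibrary
        APITypeTencent
        APITypeGemini
    )

    // ChannelType2APIType maps a channel type to the API family that serves it;
    // anything unrecognized falls back to the OpenAI-compatible adaptor.
    func ChannelType2APIType(channelType int) int {
        apiType := APITypeOpenAI
        switch channelType {
        case common.ChannelTypeAnthropic:
            apiType = APITypeClaude
        case common.ChannelTypeBaidu:
            apiType = APITypeBaidu
        case common.ChannelTypePaLM:
            apiType = APITypePaLM
        case common.ChannelTypeZhipu:
            apiType = APITypeZhipu
        case common.ChannelTypeAli:
            apiType = APITypeAli
        case common.ChannelTypeXunfei:
            apiType = APITypeXunfei
        case common.ChannelTypeAIProxyLibrary:
            apiType = APITypeAIProxyLibrary
        case common.ChannelTypeTencent:
            apiType = APITypeTencent
        case common.ChannelTypeGemini:
            apiType = APITypeGemini
        }
        return apiType
    }

Keeping the mapping in one place means the channel test and the text relay can no longer disagree about which adaptor serves a channel type.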
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 14d4746..0e28efc 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -10,7 +10,9 @@ import (
 	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
+	relay2 "one-api/relay"
+	"one-api/service"
 	"strconv"
 	"strings"
 	"time"
@@ -63,7 +65,7 @@ import (
 	req = req.WithContext(ctx)
 	req.Header.Set("Content-Type", "application/json")
-	//req.Header.Set("Authorization", "Bearer midjourney-proxy")
+	//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
 	req.Header.Set("mj-api-secret", midjourneyChannel.Key)
 	resp, err := httpClient.Do(req)
 	if err != nil {
@@ -221,7 +223,7 @@ func UpdateMidjourneyTaskBulk() {
 		req = req.WithContext(ctx)
 		req.Header.Set("Content-Type", "application/json")
 		req.Header.Set("mj-api-secret", midjourneyChannel.Key)
-		resp, err := httpClient.Do(req)
+		resp, err := service.GetHttpClient().Do(req)
 		if err != nil {
 			common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
 			continue
 		}
@@ -231,7 +233,7 @@
 			common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
 			continue
 		}
-		var responseItems []Midjourney
+		var responseItems []relay2.Midjourney
 		err = json.Unmarshal(responseBody, &responseItems)
 		if err != nil {
 			common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -284,7 +286,7 @@
 	}
 }
 
-func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
 	if oldTask.Code != 1 {
 		return true
 	}
diff --git a/controller/model.go b/controller/model.go
index 39c4acf..1721cb7 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,7 +2,6 @@ package controller
 
 import (
 	"fmt"
-
 	"github.com/gin-gonic/gin"
 )
diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go
deleted file mode 100644
index 7dbf679..0000000
--- a/controller/relay-aiproxy.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package controller
-
-import (
-	"bufio"
-	"encoding/json"
-	"fmt"
-	"github.com/gin-gonic/gin"
-	"io"
-	"net/http"
-	"one-api/common"
-	"strconv"
-	"strings"
-)
-
-// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
-
-type AIProxyLibraryRequest struct {
-	Model     string `json:"model"`
-	Query     string `json:"query"`
-	LibraryId string `json:"libraryId"`
-	Stream    bool   `json:"stream"`
-}
-
-type AIProxyLibraryError struct {
-	ErrCode int    `json:"errCode"`
-	Message string `json:"message"`
-}
-
-type AIProxyLibraryDocument
struct { - Title string `json:"title"` - URL string `json:"url"` -} - -type AIProxyLibraryResponse struct { - Success bool `json:"success"` - Answer string `json:"answer"` - Documents []AIProxyLibraryDocument `json:"documents"` - AIProxyLibraryError -} - -type AIProxyLibraryStreamResponse struct { - Content string `json:"content"` - Finish bool `json:"finish"` - Model string `json:"model"` - Documents []AIProxyLibraryDocument `json:"documents"` -} - -func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { - query := "" - if len(request.Messages) != 0 { - query = string(request.Messages[len(request.Messages)-1].Content) - } - return &AIProxyLibraryRequest{ - Model: request.Model, - Stream: request.Stream, - Query: query, - } -} - -func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { - if len(documents) == 0 { - return "" - } - content := "\n\n参考文档:\n" - for i, document := range documents { - content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL) - } - return content -} - -func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { - content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents)) - choice := OpenAITextResponseChoice{ - Index: 0, - Message: Message{ - Role: "assistant", - Content: content, - }, - FinishReason: "stop", - } - fullTextResponse := OpenAITextResponse{ - Id: common.GetUUID(), - Object: "chat.completion", - Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - } - return &fullTextResponse -} - -func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = aiProxyDocuments2Markdown(documents) - choice.FinishReason = &stopFinishReason - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "", - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } -} - -func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice - choice.Delta.Content = response.Content - return &ChatCompletionsStreamResponse{ - Id: common.GetUUID(), - Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: response.Model, - Choices: []ChatCompletionsStreamResponseChoice{choice}, - } -} - -func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 { // ignore blank line or wrong format - continue - } - if data[:5] != "data:" { - continue - } - data = data[5:] - dataChan <- data - } - stopChan <- true - }() - setEventStreamHeaders(c) - var documents []AIProxyLibraryDocument - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - var AIProxyLibraryResponse AIProxyLibraryStreamResponse - err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) - if err 
!= nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - if len(AIProxyLibraryResponse.Documents) != 0 { - documents = AIProxyLibraryResponse.Documents - } - response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - response := documentsAIProxyLibrary(documents) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err := resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - return nil, &usage -} - -func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var AIProxyLibraryResponse AIProxyLibraryResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if AIProxyLibraryResponse.ErrCode != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: AIProxyLibraryResponse.Message, - Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), - Code: AIProxyLibraryResponse.ErrCode, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage -} diff --git a/controller/relay-text.go b/controller/relay-text.go deleted file mode 100644 index ac64b66..0000000 --- a/controller/relay-text.go +++ /dev/null @@ -1,752 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "one-api/common" - "one-api/model" - "strings" - "time" - - "github.com/gin-gonic/gin" -) - -const ( - APITypeOpenAI = iota - APITypeClaude - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini -) - -var httpClient *http.Client -var impatientHTTPClient *http.Client - -func init() { - if common.RelayTimeout == 0 { - httpClient = &http.Client{} - } else { - httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, - } - } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} - -func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := 
c.GetString("group") - tokenUnlimited := c.GetBool("token_unlimited_quota") - startTime := time.Now() - var textRequest GeneralOpenAIRequest - - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - if relayMode == RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayMode == RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } - // request validation - if textRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - switch relayMode { - case RelayModeCompletions: - if textRequest.Prompt == "" { - return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEmbeddings: - case RelayModeModerations: - if textRequest.Input == "" { - return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEdits: - if textRequest.Instruction == "" { - return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) - } - } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } - } - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeClaude - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ 
= strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) - } - case APITypeClaude: - fullRequestURL = "https://api.anthropic.com/v1/complete" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) - } - case APITypeBaidu: - switch textRequest.Model { - case "ERNIE-Bot": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - case "ERNIE-Bot-turbo": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - case "ERNIE-Bot-4": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - case "BLOOMZ-7B": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - case "Embedding-V1": - fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - var err error - if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) - } - fullRequestURL += "?access_token=" + apiKey - case APITypePaLM: - fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey - case APITypeGemini: - requestBaseURL := "https://generativelanguage.googleapis.com" - if baseURL != "" { - requestBaseURL = baseURL - } - version := "v1beta" - if c.GetString("api_version") != "" { - version = c.GetString("api_version") - } - action := "generateContent" - if textRequest.Stream { - action = "streamGenerateContent" - } - fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - fullRequestURL += "?key=" + apiKey - //log.Println(fullRequestURL) - - case APITypeZhipu: - method := "invoke" - if textRequest.Stream { - method = "sse-invoke" - } - fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - case APITypeAli: - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == RelayModeEmbeddings { - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - } - case APITypeTencent: - fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - case APITypeAIProxyLibrary: - fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) - } - var promptTokens int - var completionTokens int - switch relayMode { - case RelayModeChatCompletions: - promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model) - if err != nil { - return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) - } - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModerations: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) - } - 
modelPrice := common.GetModelPrice(textRequest.Model, false) - groupRatio := common.GetGroupRatio(group) - - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - if modelPrice == -1 { - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + int(textRequest.MaxTokens) - } - modelRatio = common.GetModelRatio(textRequest.Model) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } - - userQuota, err := model.CacheGetUserQuota(userId) - if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - if userQuota < 0 || userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // 用户额度充足,判断令牌额度是否充足 - if !tokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > 100*preConsumedQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) - } - } - if preConsumedQuota > 0 { - userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - } - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - switch apiType { - case APITypeClaude: - claudeRequest := requestOpenAI2Claude(textRequest) - jsonStr, err := json.Marshal(claudeRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeBaidu: - var jsonData []byte - var err error - switch relayMode { - case RelayModeEmbeddings: - baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduEmbeddingRequest) - default: - baiduRequest := requestOpenAI2Baidu(textRequest) - jsonData, err = json.Marshal(baiduRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonData) - case APITypePaLM: - palmRequest := requestOpenAI2PaLM(textRequest) - jsonStr, err := json.Marshal(palmRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeGemini: - geminiChatRequest := requestOpenAI2Gemini(textRequest) - jsonStr, err := json.Marshal(geminiChatRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", 
http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeZhipu: - zhipuRequest := requestOpenAI2Zhipu(textRequest) - jsonStr, err := json.Marshal(zhipuRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAli: - var jsonStr []byte - var err error - switch relayMode { - case RelayModeEmbeddings: - aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliEmbeddingRequest) - default: - aliRequest := requestOpenAI2Ali(textRequest) - jsonStr, err = json.Marshal(aliRequest) - } - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeTencent: - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - appId, secretId, secretKey, err := parseTencentConfig(apiKey) - if err != nil { - return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) - } - tencentRequest := requestOpenAI2Tencent(textRequest) - tencentRequest.AppId = appId - tencentRequest.SecretId = secretId - jsonStr, err := json.Marshal(tencentRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - sign := getTencentSign(*tencentRequest, secretKey) - c.Request.Header.Set("Authorization", sign) - requestBody = bytes.NewBuffer(jsonStr) - case APITypeAIProxyLibrary: - aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") - jsonStr, err := json.Marshal(aiProxyLibraryRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } - - var req *http.Request - var resp *http.Response - isStream := textRequest.Stream - - if apiType != APITypeXunfei { // cause xunfei use websocket - req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - // 设置GetBody函数,该函数返回一个新的io.ReadCloser,该io.ReadCloser返回与原始请求体相同的数据 - req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(requestBody), nil - } - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - req.Header.Set("api-key", apiKey) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - if c.Request.Header.Get("OpenAI-Organization") != "" { - req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization")) - } - if channelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - } - case APITypeClaude: - req.Header.Set("x-api-key", apiKey) - anthropicVersion := c.Request.Header.Get("anthropic-version") - if anthropicVersion == "" { - anthropicVersion = "2023-06-01" - } - req.Header.Set("anthropic-version", anthropicVersion) - case APITypeZhipu: - token := getZhipuToken(apiKey) - req.Header.Set("Authorization", token) - case APITypeAli: - req.Header.Set("Authorization", "Bearer "+apiKey) - if textRequest.Stream { - req.Header.Set("X-DashScope-SSE", "enable") - } - case APITypeTencent: - 
req.Header.Set("Authorization", apiKey) - case APITypeGemini: - req.Header.Set("Content-Type", "application/json") - default: - req.Header.Set("Authorization", "Bearer "+apiKey) - } - if apiType != APITypeGemini { - // 设置公共头部... - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } - } - //req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection")) - resp, err = httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - - if resp.StatusCode != http.StatusOK { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return relayErrorHandler(resp) - } - } - - var textResponse TextResponse - tokenName := c.GetString("token_name") - - defer func(ctx context.Context) { - // c.Writer.Flush() - go func() { - useTimeSeconds := time.Now().Unix() - startTime.Unix() - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - - quota := 0 - if modelPrice == -1 { - completionRatio := common.GetCompletionRatio(textRequest.Model) - quota = promptTokens + int(float64(completionTokens)*completionRatio) - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } - totalTokens := promptTokens + completionTokens - var logContent string - if modelPrice == -1 { - logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } - - // record all the consume log even if quota is 0 - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - logContent += fmt.Sprintf("(有疑问请联系管理员)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota)) - } else { - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } - - logModel := textRequest.Model - if strings.HasPrefix(logModel, "gpt-4-gizmo") { - logModel = "gpt-4-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", textRequest.Model) - } - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 
completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream) - - //if quota != 0 { - // - //} - }() - }(c.Request.Context()) - switch apiType { - case APITypeOpenAI: - if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeClaude: - if isStream { - err, responseText := claudeStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeBaidu: - if isStream { - err, usage := baiduStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = baiduEmbeddingHandler(c, resp) - default: - err, usage = baiduHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypePaLM: - if textRequest.Stream { // PaLM2 API does not support stream - err, responseText := palmStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeGemini: - if textRequest.Stream { - err, responseText := geminiChatStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeZhipu: - if isStream { - err, usage := zhipuStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } else { - err, usage := zhipuHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - // zhipu's API does not return prompt tokens & completion tokens - textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - return nil - } - case APITypeAli: - if isStream { - err, usage := aliStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - var err *OpenAIErrorWithStatusCode - var usage *Usage - switch relayMode { - case RelayModeEmbeddings: - err, usage = 
aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeXunfei: - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - var err *OpenAIErrorWithStatusCode - var usage *Usage - if isStream { - err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - } else { - err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) - } - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - case APITypeAIProxyLibrary: - if isStream { - err, usage := aiProxyLibraryStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - err, usage := aiProxyLibraryHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeTencent: - if isStream { - err, responseText := tencentStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := tencentHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - default: - return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) - } -} diff --git a/controller/relay.go b/controller/relay.go index 9188b01..e22bb2d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,340 +1,34 @@ package controller import ( - "encoding/json" "fmt" + "github.com/gin-gonic/gin" "log" "net/http" "one-api/common" + "one-api/dto" + "one-api/relay" + "one-api/relay/constant" + relayconstant "one-api/relay/constant" + "one-api/service" "strconv" "strings" - - "github.com/gin-gonic/gin" ) -type Message struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` - Name *string `json:"name,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` - ToolCallId string `json:"tool_call_id,omitempty"` -} - -type MediaMessage struct { - Type string `json:"type"` - Text string `json:"text"` - ImageUrl any `json:"image_url,omitempty"` -} - -type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` -} - -const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" -) - -func (m Message) StringContent() string { - var stringContent string - if err := json.Unmarshal(m.Content, &stringContent); err == nil { - return stringContent - } - return string(m.Content) -} - -func (m Message) ParseContent() []MediaMessage { - var contentList []MediaMessage - var stringContent string - if err := json.Unmarshal(m.Content, &stringContent); err == nil { - contentList = append(contentList, MediaMessage{ - Type: ContentTypeText, - Text: stringContent, - }) - return contentList - } - var arrayContent []json.RawMessage - if err := json.Unmarshal(m.Content, &arrayContent); err == nil { - for _, contentItem := range arrayContent { - var contentMap map[string]any - if err := json.Unmarshal(contentItem, &contentMap); err != nil { - continue - } - switch contentMap["type"] { - case ContentTypeText: - if subStr, ok := 
contentMap["text"].(string); ok { - contentList = append(contentList, MediaMessage{ - Type: ContentTypeText, - Text: subStr, - }) - } - case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - detail, ok := subObj["detail"] - if ok { - subObj["detail"] = detail.(string) - } else { - subObj["detail"] = "auto" - } - contentList = append(contentList, MediaMessage{ - Type: ContentTypeImageURL, - ImageUrl: MessageImageUrl{ - Url: subObj["url"].(string), - Detail: subObj["detail"].(string), - }, - }) - } - } - } - return contentList - } - - return nil -} - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeMidjourneyImagine - RelayModeMidjourneyDescribe - RelayModeMidjourneyBlend - RelayModeMidjourneyChange - RelayModeMidjourneySimpleChange - RelayModeMidjourneyNotify - RelayModeMidjourneyTaskFetch - RelayModeMidjourneyTaskFetchByCondition - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -// https://platform.openai.com/docs/api-reference/chat - -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` - LogProbs bool `json:"logprobs,omitempty"` - TopLogProbs int `json:"top_logprobs,omitempty"` -} - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input -} - -type AudioRequest struct { - Model string `json:"model"` - Voice string `json:"voice"` - Input string `json:"input"` -} - -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens uint `json:"max_tokens"` -} - -type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens uint `json:"max_tokens"` - //Stream bool `json:"stream"` -} - -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` -} - -type AudioResponse struct { - Text string `json:"text,omitempty"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int 
`json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type OpenAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` -} - -type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` -} - -type OpenAITextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type OpenAITextResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` -} - -type OpenAIEmbeddingResponseItem struct { - Object string `json:"object"` - Index int `json:"index"` - Embedding []float64 `json:"embedding"` -} - -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` -} - -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - } -} - -type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type ChatCompletionsStreamResponseSimple struct { - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type CompletionsStreamResponse struct { - Choices []struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - -type MidjourneyRequest struct { - Prompt string `json:"prompt"` - NotifyHook string `json:"notifyHook"` - Action string `json:"action"` - Index int `json:"index"` - State string `json:"state"` - TaskId string `json:"taskId"` - Base64Array []string `json:"base64Array"` - Content string `json:"content"` -} - -type MidjourneyResponse struct { - Code int `json:"code"` - Description string `json:"description"` - Properties interface{} `json:"properties"` - Result string `json:"result"` -} - func Relay(c *gin.Context) { - relayMode := RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if 
strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - var err *OpenAIErrorWithStatusCode + relayMode := constant.Path2RelayMode(c.Request.URL.Path) + var err *dto.OpenAIErrorWithStatusCode switch relayMode { - case RelayModeImagesGenerations: - err = relayImageHelper(c, relayMode) - case RelayModeAudioSpeech: + case relayconstant.RelayModeImagesGenerations: + err = relay.RelayImageHelper(c, relayMode) + case relayconstant.RelayModeAudioSpeech: fallthrough - case RelayModeAudioTranslation: + case relayconstant.RelayModeAudioTranslation: fallthrough - case RelayModeAudioTranscription: - err = relayAudioHelper(c, relayMode) + case relayconstant.RelayModeAudioTranscription: + err = relay.RelayAudioHelper(c, relayMode) default: - err = relayTextHelper(c, relayMode) + err = relay.TextHelper(c) } if err != nil { requestId := c.GetString(common.RequestIdKey) @@ -358,42 +52,42 @@ func Relay(c *gin.Context) { autoBan := c.GetBool("auto_ban") common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan { + if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") - disableChannel(channelId, channelName, err.Message) + service.DisableChannel(channelId, channelName, err.Message) } } } func RelayMidjourney(c *gin.Context) { - relayMode := RelayModeUnknown + relayMode := relayconstant.RelayModeUnknown if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") { - relayMode = RelayModeMidjourneyImagine + relayMode = relayconstant.RelayModeMidjourneyImagine } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") { - relayMode = RelayModeMidjourneyBlend + relayMode = relayconstant.RelayModeMidjourneyBlend } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") { - relayMode = RelayModeMidjourneyDescribe + relayMode = relayconstant.RelayModeMidjourneyDescribe } else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") { - relayMode = RelayModeMidjourneyNotify + relayMode = relayconstant.RelayModeMidjourneyNotify } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") { - relayMode = RelayModeMidjourneyChange + relayMode = relayconstant.RelayModeMidjourneyChange } else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") { - relayMode = RelayModeMidjourneyChange + relayMode = relayconstant.RelayModeMidjourneyChange } else if strings.HasSuffix(c.Request.URL.Path, "/fetch") { - relayMode = RelayModeMidjourneyTaskFetch + relayMode = relayconstant.RelayModeMidjourneyTaskFetch } else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") { - relayMode = RelayModeMidjourneyTaskFetchByCondition + relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition } - var err *MidjourneyResponse + var err *dto.MidjourneyResponse switch relayMode { - case RelayModeMidjourneyNotify: - err = relayMidjourneyNotify(c) - case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition: - err = relayMidjourneyTask(c, relayMode) + case relayconstant.RelayModeMidjourneyNotify: + err = relay.RelayMidjourneyNotify(c) + case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: + err = relay.RelayMidjourneyTask(c, relayMode) default: - err = relayMidjourneySubmit(c, 
relayMode) + err = relay.RelayMidjourneySubmit(c, relayMode) } //err = relayMidjourneySubmit(c, relayMode) log.Println(err) @@ -425,7 +119,7 @@ func RelayMidjourney(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := OpenAIError{ + err := dto.OpenAIError{ Message: "API not implemented", Type: "new_api_error", Param: "", @@ -437,7 +131,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := OpenAIError{ + err := dto.OpenAIError{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", diff --git a/dto/error.go b/dto/error.go new file mode 100644 index 0000000..bfb3376 --- /dev/null +++ b/dto/error.go @@ -0,0 +1,13 @@ +package dto + +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +type OpenAIErrorWithStatusCode struct { + OpenAIError + StatusCode int `json:"status_code"` +} diff --git a/dto/request.go b/dto/request.go new file mode 100644 index 0000000..fa8770a --- /dev/null +++ b/dto/request.go @@ -0,0 +1,137 @@ +package dto + +import "encoding/json" + +type ResponseFormat struct { + Type string `json:"type,omitempty"` +} + +type GeneralOpenAIRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` + LogProbs bool `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` +} + +func (r GeneralOpenAIRequest) ParseInput() []string { + if r.Input == nil { + return nil + } + var input []string + switch r.Input.(type) { + case string: + input = []string{r.Input.(string)} + case []any: + input = make([]string, 0, len(r.Input.([]any))) + for _, item := range r.Input.([]any) { + if str, ok := item.(string); ok { + input = append(input, str) + } + } + } + return input +} + +type Message struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Name *string `json:"name,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` +} + +type MediaMessage struct { + Type string `json:"type"` + Text string `json:"text"` + ImageUrl any `json:"image_url,omitempty"` +} + +type MessageImageUrl struct { + Url string `json:"url"` + Detail string `json:"detail"` +} + +const ( + ContentTypeText = "text" + ContentTypeImageURL = "image_url" +) + +func (m Message) StringContent() string { + var stringContent string + if err := json.Unmarshal(m.Content, &stringContent); err == nil { + return stringContent + } + return string(m.Content) +} + +func (m Message) ParseContent() []MediaMessage { + var contentList []MediaMessage + var stringContent string + if err := json.Unmarshal(m.Content, 
&stringContent); err == nil { + contentList = append(contentList, MediaMessage{ + Type: ContentTypeText, + Text: stringContent, + }) + return contentList + } + var arrayContent []json.RawMessage + if err := json.Unmarshal(m.Content, &arrayContent); err == nil { + for _, contentItem := range arrayContent { + var contentMap map[string]any + if err := json.Unmarshal(contentItem, &contentMap); err != nil { + continue + } + switch contentMap["type"] { + case ContentTypeText: + if subStr, ok := contentMap["text"].(string); ok { + contentList = append(contentList, MediaMessage{ + Type: ContentTypeText, + Text: subStr, + }) + } + case ContentTypeImageURL: + if subObj, ok := contentMap["image_url"].(map[string]any); ok { + detail, ok := subObj["detail"] + if ok { + subObj["detail"] = detail.(string) + } else { + subObj["detail"] = "auto" + } + contentList = append(contentList, MediaMessage{ + Type: ContentTypeImageURL, + ImageUrl: MessageImageUrl{ + Url: subObj["url"].(string), + Detail: subObj["detail"].(string), + }, + }) + } + } + } + return contentList + } + + return nil +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} diff --git a/dto/response.go b/dto/response.go new file mode 100644 index 0000000..620c083 --- /dev/null +++ b/dto/response.go @@ -0,0 +1,86 @@ +package dto + +type TextResponse struct { + Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` + Error OpenAIError `json:"error"` +} + +type OpenAITextResponseChoice struct { + Index int `json:"index"` + Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type OpenAITextResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` +} + +type OpenAIEmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type OpenAIEmbeddingResponse struct { + Object string `json:"object"` + Data []OpenAIEmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` +} + +type ImageResponse struct { + Created int `json:"created"` + Data []struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + } +} + +type ChatCompletionsStreamResponseChoice struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +type ChatCompletionsStreamResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + +type ChatCompletionsStreamResponseSimple struct { + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + +type CompletionsStreamResponse struct { + Choices []struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +type MidjourneyRequest struct { + Prompt string `json:"prompt"` + NotifyHook string `json:"notifyHook"` + Action string `json:"action"` + Index int `json:"index"` + State string `json:"state"` + TaskId string `json:"taskId"` + Base64Array []string `json:"base64Array"` + Content string `json:"content"` +} + +type MidjourneyResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Properties interface{} `json:"properties"` + Result string 
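
The json.RawMessage content field above is what lets a single Message type accept both content shapes the OpenAI API allows: a bare string and the multimodal part array. A minimal sketch of how callers are expected to use StringContent and ParseContent (payloads are illustrative; the import path follows this patch):

    package main

    import (
        "encoding/json"
        "fmt"

        "one-api/dto"
    )

    func main() {
        // Bare string form: StringContent unwraps the JSON string.
        var plain dto.Message
        _ = json.Unmarshal([]byte(`{"role":"user","content":"hello"}`), &plain)
        fmt.Println(plain.StringContent()) // hello

        // Part-array form: ParseContent normalizes to MediaMessage values,
        // defaulting the image "detail" to "auto" when the client omits it.
        payload := `{"role":"user","content":[{"type":"text","text":"what is this?"},{"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]}`
        var multi dto.Message
        _ = json.Unmarshal([]byte(payload), &multi)
        for _, part := range multi.ParseContent() {
            fmt.Println(part.Type) // text, then image_url
        }
    }
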
`json:"result"` +} diff --git a/main.go b/main.go index 2ba7449..94496a4 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "one-api/controller" "one-api/middleware" "one-api/model" + "one-api/relay/common" "one-api/router" "os" "strconv" @@ -105,7 +106,7 @@ func main() { common.SysLog("pprof enabled") } - controller.InitTokenEncoders() + common.InitTokenEncoders() // Initialize HTTP server server := gin.New() diff --git a/middleware/distributor.go b/middleware/distributor.go index e105fb8..1ca43dd 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -129,15 +129,18 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) + // TODO: api_version统一 switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set("library_id", channel.Other) + //case common.ChannelTypeAIProxyLibrary: + // c.Set("library_id", channel.Other) case common.ChannelTypeGemini: c.Set("api_version", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } c.Next() } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go new file mode 100644 index 0000000..ee696fd --- /dev/null +++ b/relay/channel/adapter.go @@ -0,0 +1,57 @@ +package channel + +import ( + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel/ali" + "one-api/relay/channel/baidu" + "one-api/relay/channel/claude" + "one-api/relay/channel/gemini" + "one-api/relay/channel/openai" + "one-api/relay/channel/palm" + "one-api/relay/channel/tencent" + "one-api/relay/channel/xunfei" + "one-api/relay/channel/zhipu" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor interface { + // Init IsStream bool + Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) + GetRequestURL(info *relaycommon.RelayInfo) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error + ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) + DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) + GetModelList() []string + GetChannelName() string +} + +func GetAdaptor(apiType int) Adaptor { + switch apiType { + //case constant.APITypeAIProxyLibrary: + // return &aiproxy.Adaptor{} + case constant.APITypeAli: + return &ali.Adaptor{} + case constant.APITypeAnthropic: + return &claude.Adaptor{} + case constant.APITypeBaidu: + return &baidu.Adaptor{} + case constant.APITypeGemini: + return &gemini.Adaptor{} + case constant.APITypeOpenAI: + return &openai.Adaptor{} + case constant.APITypePaLM: + return &palm.Adaptor{} + case constant.APITypeTencent: + return &tencent.Adaptor{} + case constant.APITypeXunfei: + return &xunfei.Adaptor{} + case constant.APITypeZhipu: + return &zhipu.Adaptor{} + } + return nil +} diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go new file mode 100644 index 0000000..b79299a --- /dev/null +++ b/relay/channel/ali/adaptor.go @@ -0,0 +1,80 @@ +package ali + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + 
relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) + if info.RelayMode == constant.RelayModeEmbeddings { + fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) + } + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + if info.IsStream { + req.Header.Set("X-DashScope-SSE", "enable") + } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := requestOpenAI2Ali(*request) + return baiduRequest, nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = aliStreamHandler(c, resp) + } else { + switch info.RelayMode { + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + err, usage = aliHandler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go new file mode 100644 index 0000000..6f6658c --- /dev/null +++ b/relay/channel/ali/constants.go @@ -0,0 +1,8 @@ +package ali + +var ModelList = []string{ + "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", + "text-embedding-v1", +} + +var ChannelName = "ali" diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go new file mode 100644 index 0000000..bc63f63 --- /dev/null +++ b/relay/channel/ali/dto.go @@ -0,0 +1,70 @@ +package ali + +type AliMessage struct { + User string `json:"user"` + Bot string `json:"bot"` +} + +type AliInput struct { + Prompt string `json:"prompt"` + History []AliMessage `json:"history"` +} + +type AliParameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` +} + +type AliChatRequest struct { + Model string `json:"model"` + Input AliInput `json:"input"` + Parameters AliParameters `json:"parameters,omitempty"` +} + +type AliEmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type AliEmbedding struct { + Embedding []float64 
`json:"embedding"` + TextIndex int `json:"text_index"` +} + +type AliEmbeddingResponse struct { + Output struct { + Embeddings []AliEmbedding `json:"embeddings"` + } `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + +type AliError struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type AliUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type AliOutput struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type AliChatResponse struct { + Output AliOutput `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} diff --git a/controller/relay-ali.go b/relay/channel/ali/relay-ali.go similarity index 61% rename from controller/relay-ali.go rename to relay/channel/ali/relay-ali.go index 93839d9..2f39087 100644 --- a/controller/relay-ali.go +++ b/relay/channel/ali/relay-ali.go @@ -1,4 +1,4 @@ -package controller +package ali import ( "bufio" @@ -7,81 +7,14 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + "one-api/service" "strings" ) // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r -type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` -} - -type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` -} - -type AliParameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` -} - -type AliChatRequest struct { - Model string `json:"model"` - Input AliInput `json:"input"` - Parameters AliParameters `json:"parameters,omitempty"` -} - -type AliEmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters *struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -type AliEmbedding struct { - Embedding []float64 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type AliEmbeddingResponse struct { - Output struct { - Embeddings []AliEmbedding `json:"embeddings"` - } `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -type AliError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestId string `json:"request_id"` -} - -type AliUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` -} - -type AliChatResponse struct { - Output AliOutput `json:"output"` - Usage AliUsage `json:"usage"` - AliError -} - -func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { +func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) prompt := "" for i := 0; i < len(request.Messages); i++ { @@ -119,7 +52,7 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { } } -func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest { +func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { return &AliEmbeddingRequest{ Model: "text-embedding-v1", Input: struct { @@ -130,21 +63,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque } } -func 
aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var aliResponse AliEmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -157,7 +90,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -165,16 +98,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS return nil, &fullTextResponse.Usage } -func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse { + openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)), Model: "text-embedding-v1", - Usage: Usage{TotalTokens: response.Usage.TotalTokens}, + Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens}, } for _, item := range response.Output.Embeddings { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{ Object: `embedding`, Index: item.TextIndex, Embedding: item.Embedding, @@ -183,22 +116,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { +func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(response.Output.Text) - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: 0, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, }, FinishReason: response.Output.FinishReason, } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := dto.OpenAITextResponse{ Id: response.RequestId, Object: "chat.completion", Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, - Usage: Usage{ + Choices: []dto.OpenAITextResponseChoice{choice}, + Usage: dto.Usage{ PromptTokens: response.Usage.InputTokens, 
CompletionTokens: response.Usage.OutputTokens, TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, @@ -207,25 +140,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = aliResponse.Output.Text if aliResponse.Output.FinishReason != "null" { finishReason := aliResponse.Output.FinishReason choice.FinishReason = &finishReason } - response := ChatCompletionsStreamResponse{ + response := dto.ChatCompletionsStreamResponse{ Id: aliResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var usage dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -255,7 +188,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) lastResponseText := "" c.Stream(func(w io.Writer) bool { select { @@ -288,28 +221,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var aliResponse AliChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &aliResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if aliResponse.Code != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: aliResponse.Message, Type: aliResponse.Code, Param: aliResponse.RequestId, @@ -321,7 +254,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode fullTextResponse := responseAli2OpenAI(&aliResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err 
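
One subtlety in aliStreamHandler sits in context lines this hunk elides: DashScope's SSE events carry the full text generated so far rather than a delta, which is why the handler tracks lastResponseText. A sketch of the suffix-trimming pattern (helper name hypothetical; requires strings):

    // nextDelta returns only the newly generated suffix, given that the
    // upstream resends the whole accumulated completion on every event.
    func nextDelta(lastText, fullText string) (delta string, next string) {
        if strings.HasPrefix(fullText, lastText) {
            return fullText[len(lastText):], fullText
        }
        return fullText, fullText // fall back to forwarding everything
    }

    // Inside the stream loop (sketch):
    //   delta, lastResponseText = nextDelta(lastResponseText, aliResponse.Output.Text)
    //   choice.Delta.Content = delta
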
!= nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go new file mode 100644 index 0000000..b0ef212 --- /dev/null +++ b/relay/channel/api_request.go @@ -0,0 +1,52 @@ +package channel + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + relaycommon "one-api/relay/common" + "one-api/service" +) + +func SetupApiRequestHeader(info *relaycommon.RelayInfo, c *gin.Context, req *http.Request) { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if info.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } +} + +func DoApiRequest(a Adaptor, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { + resp, err := service.GetHttpClient().Do(req) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("resp is nil") + } + _ = req.Body.Close() + _ = c.Request.Body.Close() + return resp, nil +} diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go new file mode 100644 index 0000000..a07fccd --- /dev/null +++ b/relay/channel/baidu/adaptor.go @@ -0,0 +1,92 @@ +package baidu + +import ( + "errors" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + var fullRequestURL string + switch info.UpstreamModelName { + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + case "ERNIE-Bot-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Speed": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" + case "ERNIE-Bot-turbo": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "Embedding-V1": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + } + var accessToken string + var err error + if 
accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { + return "", err + } + fullRequestURL += "?access_token=" + accessToken + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) + return baiduEmbeddingRequest, nil + default: + baiduRequest := requestOpenAI2Baidu(*request) + return baiduRequest, nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = baiduStreamHandler(c, resp) + } else { + switch info.RelayMode { + case constant.RelayModeEmbeddings: + err, usage = baiduEmbeddingHandler(c, resp) + default: + err, usage = baiduHandler(c, resp) + } + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go new file mode 100644 index 0000000..a0162bb --- /dev/null +++ b/relay/channel/baidu/constants.go @@ -0,0 +1,12 @@ +package baidu + +var ModelList = []string{ + "ERNIE-Bot-4", + "ERNIE-Bot-8K", + "ERNIE-Bot", + "ERNIE-Speed", + "ERNIE-Bot-turbo", + "Embedding-V1", +} + +var ChannelName = "baidu" diff --git a/relay/channel/baidu/dto.go b/relay/channel/baidu/dto.go new file mode 100644 index 0000000..2c37698 --- /dev/null +++ b/relay/channel/baidu/dto.go @@ -0,0 +1,71 @@ +package baidu + +import ( + "one-api/dto" + "time" +) + +type BaiduMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaiduChatRequest struct { + Messages []BaiduMessage `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` +} + +type Error struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +type BaiduChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage dto.Usage `json:"usage"` + Error +} + +type BaiduChatStreamResponse struct { + BaiduChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type BaiduEmbeddingRequest struct { + Input []string `json:"input"` +} + +type BaiduEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type BaiduEmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []BaiduEmbeddingData `json:"data"` + Usage dto.Usage `json:"usage"` + Error +} + +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string 
`json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +type BaiduTokenResponse struct { + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` +} diff --git a/controller/relay-baidu.go b/relay/channel/baidu/relay-baidu.go similarity index 63% rename from controller/relay-baidu.go rename to relay/channel/baidu/relay-baidu.go index 3dd9ba6..fb58550 100644 --- a/controller/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -1,4 +1,4 @@ -package controller +package baidu import ( "bufio" @@ -9,6 +9,9 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" "strings" "sync" "time" @@ -16,74 +19,9 @@ import ( // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 -type BaiduTokenResponse struct { - ExpiresIn int `json:"expires_in"` - AccessToken string `json:"access_token"` -} - -type BaiduMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` -} - -type BaiduError struct { - ErrorCode int `json:"error_code"` - ErrorMsg string `json:"error_msg"` -} - -type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduChatStreamResponse struct { - BaiduChatResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` -} - -type BaiduEmbeddingRequest struct { - Input []string `json:"input"` -} - -type BaiduEmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -type BaiduEmbeddingResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Data []BaiduEmbeddingData `json:"data"` - Usage Usage `json:"usage"` - BaiduError -} - -type BaiduAccessToken struct { - AccessToken string `json:"access_token"` - Error string `json:"error,omitempty"` - ErrorDescription string `json:"error_description,omitempty"` - ExpiresIn int64 `json:"expires_in,omitempty"` - ExpiresAt time.Time `json:"-"` -} - var baiduTokenStore sync.Map -func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { +func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { messages := make([]BaiduMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -108,57 +46,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } } -func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { +func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(response.Result) - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: 0, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, }, FinishReason: "stop", } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := dto.OpenAITextResponse{ Id: response.Id, Object: "chat.completion", Created: response.Created, - Choices: []OpenAITextResponseChoice{choice}, + Choices: []dto.OpenAITextResponseChoice{choice}, Usage: response.Usage, } return &fullTextResponse } -func 
streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = baiduResponse.Result if baiduResponse.IsEnd { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &relaycommon.StopFinishReason } - response := ChatCompletionsStreamResponse{ + response := dto.ChatCompletionsStreamResponse{ Id: baiduResponse.Id, Object: "chat.completion.chunk", Created: baiduResponse.Created, Model: "ernie-bot", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { +func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest { return &BaiduEmbeddingRequest{ Input: request.ParseInput(), } } -func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { - openAIEmbeddingResponse := OpenAIEmbeddingResponse{ +func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse { + openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{ Object: "list", - Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)), Model: "baidu-embedding", Usage: response.Usage, } for _, item := range response.Data { - openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{ Object: item.Object, Index: item.Index, Embedding: item.Embedding, @@ -167,8 +105,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe return &openAIEmbeddingResponse } -func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage Usage +func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var usage dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -195,7 +133,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -225,28 +163,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &usage } -func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var baiduResponse BaiduChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err 
= resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -258,7 +196,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse := responseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -266,23 +204,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo return nil, &fullTextResponse.Usage } -func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var baiduResponse BaiduEmbeddingResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if baiduResponse.ErrorMsg != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: baiduResponse.ErrorMsg, Type: "baidu_error", Param: "", @@ -294,7 +232,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) @@ -337,7 +275,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := impatientHTTPClient.Do(req) + res, err := service.GetImpatientHttpClient().Do(req) if err != nil { 
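
getBaiduAccessToken, called from the adaptor's GetRequestURL above, is defined in context lines this hunk skips; the baiduTokenStore sync.Map together with the ExpiresAt field makes the intended shape clear. A plausible sketch of that cache, not the verbatim implementation (requires time; everything else is in-package):

    func getBaiduAccessTokenSketch(apiKey string) (string, error) {
        // Serve from baiduTokenStore while the cached token is still valid.
        if v, ok := baiduTokenStore.Load(apiKey); ok {
            if token, ok := v.(BaiduAccessToken); ok && time.Now().Before(token.ExpiresAt) {
                return token.AccessToken, nil
            }
        }
        // Otherwise fetch a fresh token and cache it keyed by API key.
        token, err := getBaiduAccessTokenHelper(apiKey)
        if err != nil {
            return "", err
        }
        baiduTokenStore.Store(apiKey, *token)
        return token.AccessToken, nil
    }
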
return nil, err } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go new file mode 100644 index 0000000..130024b --- /dev/null +++ b/relay/channel/claude/adaptor.go @@ -0,0 +1,65 @@ +package claude + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("x-api-key", info.ApiKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = claudeStreamHandler(c, resp) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go new file mode 100644 index 0000000..1015fb8 --- /dev/null +++ b/relay/channel/claude/constants.go @@ -0,0 +1,7 @@ +package claude + +var ModelList = []string{ + "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", +} + +var ChannelName = "claude" diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go new file mode 100644 index 0000000..2231e5f --- /dev/null +++ b/relay/channel/claude/dto.go @@ -0,0 +1,29 @@ +package claude + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample uint `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ClaudeResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` +} diff --git a/controller/relay-claude.go b/relay/channel/claude/relay-claude.go similarity index 68% rename from controller/relay-claude.go rename to 
relay/channel/claude/relay-claude.go index ee7abc7..186564f 100644 --- a/controller/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -1,4 +1,4 @@ -package controller +package claude import ( "bufio" @@ -8,37 +8,11 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + "one-api/service" "strings" ) -type ClaudeMetadata struct { - UserId string `json:"user_id"` -} - -type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - MaxTokensToSample uint `json:"max_tokens_to_sample"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` -} - -type ClaudeError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -type ClaudeResponse struct { - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` -} - func stopReasonClaude2OpenAI(reason string) string { switch reason { case "stop_sequence": @@ -50,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string { } } -func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { +func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -78,41 +52,41 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { return &claudeRequest } -func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = claudeResponse.Completion finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) if finishReason != "null" { choice.FinishReason = &finishReason } - var response ChatCompletionsStreamResponse + var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} return &response } -func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { +func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: 0, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, Name: nil, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } - fullTextResponse := OpenAITextResponse{ + fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion", Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Choices: []dto.OpenAITextResponseChoice{choice}, } return &fullTextResponse } -func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createdTime := 
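
Because the adaptor targets the legacy /v1/complete endpoint, requestOpenAI2Claude has to flatten the chat message array into Anthropic's single-prompt transcript; the loop body itself sits in elided context lines. A hypothetical sketch of that conventional format (requires strings and one-api/dto):

    // buildClaudePrompt illustrates the flattening: the legacy complete API
    // expects one "\n\nHuman: ... \n\nAssistant:" transcript, not messages.
    func buildClaudePrompt(messages []dto.Message) string {
        var b strings.Builder
        for _, m := range messages {
            switch m.Role {
            case "system":
                b.WriteString(m.StringContent()) // system text leads the prompt
            case "assistant":
                b.WriteString("\n\nAssistant: " + m.StringContent())
            default: // "user" and anything else
                b.WriteString("\n\nHuman: " + m.StringContent())
            }
        }
        b.WriteString("\n\nAssistant:") // hand the turn to the model
        return b.String()
    }
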
common.GetTimestamp() @@ -142,7 +116,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -172,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var claudeResponse ClaudeResponse err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if claudeResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: claudeResponse.Error.Message, Type: claudeResponse.Error.Type, Param: "", @@ -203,8 +177,8 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model }, nil } fullTextResponse := responseClaude2OpenAI(&claudeResponse) - completionTokens := countTokenText(claudeResponse.Completion, model) - usage := Usage{ + completionTokens := service.CountTokenText(claudeResponse.Completion, model) + usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -212,7 +186,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go new file mode 100644 index 0000000..5a200eb --- /dev/null +++ b/relay/channel/gemini/adaptor.go @@ -0,0 +1,64 @@ +package gemini + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) 
(string, error) { + version := "v1" + action := "generateContent" + if info.IsStream { + action = "streamGenerateContent" + } + return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("x-goog-api-key", info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return CovertGemini2OpenAI(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = geminiChatStreamHandler(c, resp) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go new file mode 100644 index 0000000..24e85a8 --- /dev/null +++ b/relay/channel/gemini/constant.go @@ -0,0 +1,12 @@ +package gemini + +const ( + GeminiVisionMaxImageNum = 16 +) + +var ModelList = []string{ + "gemini-pro", + "gemini-pro-vision", +} + +var ChannelName = "google gemini" diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go new file mode 100644 index 0000000..a581c68 --- /dev/null +++ b/relay/channel/gemini/dto.go @@ -0,0 +1,62 @@ +package gemini + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` 
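
Unlike the OpenAI-style channels, the Gemini adaptor selects streaming by swapping the action segment of the path rather than setting a body flag, and it authenticates with the x-goog-api-key header instead of Authorization. With the stock upstream as base URL (illustrative), gemini-pro resolves to:

    https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent        (blocking)
    https://generativelanguage.googleapis.com/v1/models/gemini-pro:streamGenerateContent  (SSE)
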
+ Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} diff --git a/controller/relay-gemini.go b/relay/channel/gemini/relay-gemini.go similarity index 62% rename from controller/relay-gemini.go rename to relay/channel/gemini/relay-gemini.go index 7235158..83ede7f 100644 --- a/controller/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1,4 +1,4 @@ -package controller +package gemini import ( "bufio" @@ -7,57 +7,16 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" "strings" "github.com/gin-gonic/gin" ) -const ( - GeminiVisionMaxImageNum = 16 -) - -type GeminiChatRequest struct { - Contents []GeminiChatContent `json:"contents"` - SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` - Tools []GeminiChatTools `json:"tools,omitempty"` -} - -type GeminiInlineData struct { - MimeType string `json:"mimeType"` - Data string `json:"data"` -} - -type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` -} - -type GeminiChatContent struct { - Role string `json:"role,omitempty"` - Parts []GeminiPart `json:"parts"` -} - -type GeminiChatSafetySettings struct { - Category string `json:"category"` - Threshold string `json:"threshold"` -} - -type GeminiChatTools struct { - FunctionDeclarations any `json:"functionDeclarations,omitempty"` -} - -type GeminiChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` -} - // Setting safety to the lowest possible values since Gemini is already powerless enough -func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { +func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest { geminiRequest := GeminiChatRequest{ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), SafetySettings: []GeminiChatSafetySettings{ @@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { imageNum := 0 for _, part := range openaiContent { - if part.Type == ContentTypeText { + if part.Type == dto.ContentTypeText { parts = append(parts, GeminiPart{ Text: part.Text, }) - } else if part.Type == ContentTypeImageURL { + } else if part.Type == dto.ContentTypeImageURL { imageNum += 1 if imageNum > GeminiVisionMaxImageNum { continue } - mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url) + mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) parts = append(parts, GeminiPart{ InlineData: &GeminiInlineData{ MimeType: mimeType, @@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { return &geminiRequest } -type GeminiChatResponse struct { - Candidates []GeminiChatCandidate `json:"candidates"` - PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` -} - func (g *GeminiChatResponse) GetResponseText() string { if 
g == nil { return "" @@ -169,38 +123,22 @@ func (g *GeminiChatResponse) GetResponseText() string { return "" } -type GeminiChatCandidate struct { - Content GeminiChatContent `json:"content"` - FinishReason string `json:"finishReason"` - Index int64 `json:"index"` - SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` -} - -type GeminiChatSafetyRating struct { - Category string `json:"category"` - Probability string `json:"probability"` -} - -type GeminiChatPromptFeedback struct { - SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` -} - -func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { + fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion", Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), } content, _ := json.Marshal("") for i, candidate := range response.Candidates { - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: i, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, }, - FinishReason: stopFinishReason, + FinishReason: relaycommon.StopFinishReason, } content, _ = json.Marshal(candidate.Content.Parts[0].Text) if len(candidate.Content.Parts) > 0 { @@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse return &fullTextResponse } -func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = geminiResponse.GetResponseText() - choice.FinishReason = &stopFinishReason - var response ChatCompletionsStreamResponse + choice.FinishReason = &relaycommon.StopFinishReason + var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "gemini" - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" dataChan := make(chan string) stopChan := make(chan bool) @@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW var dummy dummyStruct err := json.Unmarshal([]byte(data), &dummy) responseText += dummy.Content - var choice ChatCompletionsStreamResponseChoice + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = dummy.Content - response := ChatCompletionsStreamResponse{ + response := dto.ChatCompletionsStreamResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "gemini-pro", 
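// An illustrative sketch (not part of the patch itself) of the multimodal
// mapping the conversion function above performs: OpenAI "image_url" parts
// become Gemini inlineData parts (MIME type plus base64 payload), capped at
// GeminiVisionMaxImageNum. The local examplePart type and the fetchImage
// callback are hypothetical stand-ins for the dto content types and
// common.GetImageFromUrl.
type examplePart struct{ Kind, Text, URL string }

func toGeminiParts(in []examplePart, fetchImage func(url string) (mimeType, base64Data string)) []GeminiPart {
	parts := make([]GeminiPart, 0, len(in))
	imageNum := 0
	for _, p := range in {
		switch p.Kind {
		case "text":
			parts = append(parts, GeminiPart{Text: p.Text})
		case "image_url":
			imageNum++
			if imageNum > GeminiVisionMaxImageNum {
				continue // images beyond the cap are silently dropped
			}
			mimeType, data := fetchImage(p.URL)
			parts = append(parts, GeminiPart{InlineData: &GeminiInlineData{MimeType: mimeType, Data: data}})
		}
	}
	return parts
}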
- Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } jsonResponse, err := json.Marshal(response) if err != nil { @@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var geminiResponse GeminiChatResponse err = json.Unmarshal(responseBody, &geminiResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if len(geminiResponse.Candidates) == 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: "No candidates returned", Type: "server_error", Param: "", @@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens := countTokenText(geminiResponse.GetResponseText(), model) - usage := Usage{ + completionTokens := service.CountTokenText(geminiResponse.GetResponseText(), model) + usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/moonshot/constants.go b/relay/channel/moonshot/constants.go new file mode 100644 index 0000000..1b86f0f --- /dev/null +++ b/relay/channel/moonshot/constants.go @@ -0,0 +1,7 @@ +package moonshot + +var ModelList = []string{ + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go new file mode 100644 index 0000000..bd01965 --- /dev/null +++ b/relay/channel/openai/adaptor.go @@ -0,0 +1,84 @@ +package openai + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaychannel 
"one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.ChannelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + requestURL := strings.Split(info.RequestURLPath, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion) + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := info.UpstreamModelName + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + } + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + if info.ChannelType == common.ChannelTypeAzure { + req.Header.Set("api-key", info.ApiKey) + return nil + } + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + if info.ChannelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = openaiStreamHandler(c, resp, info.RelayMode) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go new file mode 100644 index 0000000..91f4e51 --- /dev/null +++ b/relay/channel/openai/constant.go @@ -0,0 +1,21 @@ +package openai + +var ModelList = []string{ + "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-instruct", + "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", + "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", + "gpt-4-turbo-preview", + "gpt-4-vision-preview", + "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", + "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", + "text-moderation-latest", 
"text-moderation-stable", + "text-davinci-edit-001", + "davinci-002", "babbage-002", + "dall-e-2", "dall-e-3", + "whisper-1", + "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", +} + +var ChannelName = "openai" diff --git a/controller/relay-openai.go b/relay/channel/openai/relay-openai.go similarity index 74% rename from controller/relay-openai.go rename to relay/channel/openai/relay-openai.go index 24cfd06..c0b6353 100644 --- a/controller/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,4 +1,4 @@ -package controller +package openai import ( "bufio" @@ -8,12 +8,15 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + relayconstant "one-api/relay/constant" + "one-api/service" "strings" "sync" "time" ) -func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { +func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -54,8 +57,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } streamResp := "[" + strings.Join(streamItems, ",") + "]" switch relayMode { - case RelayModeChatCompletions: - var streamResponses []ChatCompletionsStreamResponseSimple + case relayconstant.RelayModeChatCompletions: + var streamResponses []dto.ChatCompletionsStreamResponseSimple err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -66,8 +69,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O responseTextBuilder.WriteString(choice.Delta.Content) } } - case RelayModeCompletions: - var streamResponses []CompletionsStreamResponse + case relayconstant.RelayModeCompletions: + var streamResponses []dto.CompletionsStreamResponse err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -85,7 +88,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } common.SafeSend(stopChan, true) }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -102,28 +105,28 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } wg.Wait() return nil, responseTextBuilder.String() } -func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { - var textResponse TextResponse +func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var textResponse dto.TextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", 
http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ + return &dto.OpenAIErrorWithStatusCode{ OpenAIError: textResponse.Error, StatusCode: resp.StatusCode, }, nil @@ -140,19 +143,19 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(string(choice.Message.Content), model) + completionTokens += service.CountTokenText(string(choice.Message.Content), model) } - textResponse.Usage = Usage{ + textResponse.Usage = dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go new file mode 100644 index 0000000..2a5f017 --- /dev/null +++ b/relay/channel/palm/adaptor.go @@ -0,0 +1,59 @@ +package palm + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("x-goog-api-key", info.ApiKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = palmStreamHandler(c, resp) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) 
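// Every channel package in this refactor implements the same method set; the
// relay/channel/adapter.go hunk is not shown in this part of the patch, but
// judging from the openai and palm implementations the shared interface
// presumably looks close to this sketch (the name ChannelAdaptor is a
// placeholder, not the actual identifier):
type ChannelAdaptor interface {
	Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
	GetRequestURL(info *relaycommon.RelayInfo) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode)
	GetModelList() []string
	GetChannelName() string
}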
GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/palm/constants.go b/relay/channel/palm/constants.go new file mode 100644 index 0000000..b5c881b --- /dev/null +++ b/relay/channel/palm/constants.go @@ -0,0 +1,7 @@ +package palm + +var ModelList = []string{ + "PaLM-2", +} + +var ChannelName = "google palm" diff --git a/relay/channel/palm/dto.go b/relay/channel/palm/dto.go new file mode 100644 index 0000000..46cf59d --- /dev/null +++ b/relay/channel/palm/dto.go @@ -0,0 +1,38 @@ +package palm + +import "one-api/dto" + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + +type PaLMChatRequest struct { + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK uint `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type PaLMChatResponse struct { + Candidates []PaLMChatMessage `json:"candidates"` + Messages []dto.Message `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` +} diff --git a/controller/relay-palm.go b/relay/channel/palm/relay-palm.go similarity index 63% rename from controller/relay-palm.go rename to relay/channel/palm/relay-palm.go index aa96c0d..20706df 100644 --- a/controller/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -1,4 +1,4 @@ -package controller +package palm import ( "encoding/json" @@ -7,47 +7,15 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" ) // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -type PaLMChatMessage struct { - Author string `json:"author"` - Content string `json:"content"` -} - -type PaLMFilter struct { - Reason string `json:"reason"` - Message string `json:"message"` -} - -type PaLMPrompt struct { - Messages []PaLMChatMessage `json:"messages"` -} - -type PaLMChatRequest struct { - Prompt PaLMPrompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK uint `json:"topK,omitempty"` -} - -type PaLMError struct { - Code int `json:"code"` - Message string `json:"message"` - Status string `json:"status"` -} - -type PaLMChatResponse struct { - Candidates []PaLMChatMessage `json:"candidates"` - Messages []Message `json:"messages"` - Filters []PaLMFilter `json:"filters"` - Error PaLMError `json:"error"` -} - -func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { +func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest { palmRequest := PaLMChatRequest{ Prompt: PaLMPrompt{ Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), @@ -71,15 +39,15 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { return &palmRequest } -func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ - Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), +func 
responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { + fullTextResponse := dto.OpenAITextResponse{ + Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { content, _ := json.Marshal(candidate.Content) - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: i, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, }, @@ -90,20 +58,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { return &fullTextResponse } -func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice if len(palmResponse.Candidates) > 0 { choice.Delta.Content = palmResponse.Candidates[0].Content } - choice.FinishReason = &stopFinishReason - var response ChatCompletionsStreamResponse + choice.FinishReason = &relaycommon.StopFinishReason + var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = "palm2" - response.Choices = []ChatCompletionsStreamResponseChoice{choice} + response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} return &response } -func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createdTime := common.GetTimestamp() @@ -144,7 +112,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta dataChan <- string(jsonResponse) stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -157,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { +func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { - return 
&OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: palmResponse.Error.Message, Type: palmResponse.Error.Status, Param: "", @@ -188,8 +156,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) - usage := Usage{ + completionTokens := service.CountTokenText(palmResponse.Candidates[0].Content, model) + usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, @@ -197,7 +165,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go new file mode 100644 index 0000000..58b5e0d --- /dev/null +++ b/relay/channel/tencent/adaptor.go @@ -0,0 +1,73 @@ +package tencent + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +type Adaptor struct { + Sign string +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", a.Sign) + req.Header.Set("X-TC-Action", info.UpstreamModelName) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := parseTencentConfig(apiKey) + if err != nil { + return nil, err + } + tencentRequest := requestOpenAI2Tencent(*request) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + // we have to calculate the sign here + a.Sign = getTencentSign(*tencentRequest, secretKey) + return tencentRequest, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + var responseText string + err, responseText = tencentStreamHandler(c, resp) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = tencentHandler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return 
ChannelName
+}
diff --git a/relay/channel/tencent/constants.go b/relay/channel/tencent/constants.go
new file mode 100644
index 0000000..7424ba4
--- /dev/null
+++ b/relay/channel/tencent/constants.go
@@ -0,0 +1,9 @@
+package tencent
+
+var ModelList = []string{
+	"ChatPro",
+	"ChatStd",
+	"hunyuan",
+}
+
+var ChannelName = "tencent"
diff --git a/relay/channel/tencent/dto.go b/relay/channel/tencent/dto.go
new file mode 100644
index 0000000..66f13fb
--- /dev/null
+++ b/relay/channel/tencent/dto.go
@@ -0,0 +1,61 @@
+package tencent
+
+import "one-api/dto"
+
+type TencentMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type TencentChatRequest struct {
+	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
+	SecretId string `json:"secret_id"` // SecretId from the console
+	// Timestamp is the current UNIX timestamp in seconds, recording when the
+	// API request was made, e.g. 1529223702; if it deviates too far from the
+	// current time, a signature-expired error is returned
+	Timestamp int64 `json:"timestamp"`
+	// Expired is the signature expiry, a UNIX-epoch timestamp in seconds;
+	// Expired must be greater than Timestamp, and Expired-Timestamp must be
+	// less than 90 days
+	Expired int64  `json:"expired"`
+	QueryID string `json:"query_id"` // request Id, used for troubleshooting
+	// Temperature: higher values make the output more random, lower values
+	// make it more focused and deterministic. Default 1.0, range [0.0,2.0];
+	// avoid changing it unless necessary, unreasonable values hurt quality;
+	// set only one of this and top_p, do not change both
+	Temperature float64 `json:"temperature"`
+	// TopP controls the diversity of the output text; larger values produce
+	// more diverse text. Default 1.0, range [0.0, 1.0]; avoid changing it
+	// unless necessary, unreasonable values hurt quality;
+	// set only one of this and temperature, do not change both
+	TopP float64 `json:"top_p"`
+	// Stream 0: synchronous, 1: streaming (default, protocol: SSE).
+	// Synchronous requests time out after 60s; prefer streaming for long content
+	Stream int `json:"stream"`
+	// Messages holds the conversation, at most 40 entries, ordered oldest to
+	// newest; total input content is capped at 3000 tokens
+	Messages []TencentMessage `json:"messages"`
+}
+
+type TencentError struct {
+	Code    int    `json:"code"`
+	Message string `json:"message"`
+}
+
+type TencentUsage struct {
+	InputTokens  int `json:"input_tokens"`
+	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
+}
+
+type TencentResponseChoices struct {
+	FinishReason string         `json:"finish_reason,omitempty"` // stream end flag, "stop" marks the final packet
+	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; output capped at 1024 tokens
+	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; output capped at 1024 tokens
+}
+
+type TencentChatResponse struct {
+	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
+	Created string                   `json:"created,omitempty"` // unix timestamp as a string
+	Id      string                   `json:"id,omitempty"`      // conversation id
+	Usage   dto.Usage                `json:"usage,omitempty"`   // token counts
+	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null
+	Note    string                   `json:"note,omitempty"`    // note
+	ReqID   string                   `json:"req_id,omitempty"`  // unique request Id returned with every request, for reporting issues with the call
+}
diff --git a/controller/relay-tencent.go b/relay/channel/tencent/relay-tencent.go
similarity index 57%
rename from controller/relay-tencent.go
rename to relay/channel/tencent/relay-tencent.go
index 8db6492..b990c6f 100644
--- a/controller/relay-tencent.go
+++ b/relay/channel/tencent/relay-tencent.go
@@ -1,4 +1,4 @@
-package controller
+package tencent
 
 import (
 	"bufio"
@@ -12,6 +12,9 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
 	"sort"
 	"strconv"
 	"strings"
@@ -19,65 +22,7 @@ import (
 
 // https://cloud.tencent.com/document/product/1729/97732
 
-type TencentMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type TencentChatRequest struct {
-	AppId    int64  `json:"app_id"`    // APPID of the Tencent Cloud account
-	SecretId string `json:"secret_id"` // SecretId from the console
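// parseTencentConfig and getTencentSign are referenced by the tencent adaptor
// above, but their hunks are elided in this patch view. Given how
// ConvertRequest uses them, the config parser presumably splits the channel
// key along these lines (errors, strconv and strings assumed imported; the
// sketch's name marks it as a stand-in):
func parseTencentConfigSketch(config string) (appId int64, secretId string, secretKey string, err error) {
	parts := strings.Split(config, "|") // expected shape: "appId|secretId|secretKey"
	if len(parts) != 3 {
		return 0, "", "", errors.New("invalid tencent config")
	}
	appId, err = strconv.ParseInt(parts[0], 10, 64)
	return appId, parts[1], parts[2], err
}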
-	// Timestamp is the current UNIX timestamp in seconds, recording when the
-	// API request was made, e.g. 1529223702; if it deviates too far from the
-	// current time, a signature-expired error is returned
-	Timestamp int64 `json:"timestamp"`
-	// Expired is the signature expiry, a UNIX-epoch timestamp in seconds;
-	// Expired must be greater than Timestamp, and Expired-Timestamp must be
-	// less than 90 days
-	Expired int64  `json:"expired"`
-	QueryID string `json:"query_id"` // request Id, used for troubleshooting
-	// Temperature: higher values make the output more random, lower values
-	// make it more focused and deterministic. Default 1.0, range [0.0,2.0];
-	// avoid changing it unless necessary, unreasonable values hurt quality;
-	// set only one of this and top_p, do not change both
-	Temperature float64 `json:"temperature"`
-	// TopP controls the diversity of the output text; larger values produce
-	// more diverse text. Default 1.0, range [0.0, 1.0]; avoid changing it
-	// unless necessary, unreasonable values hurt quality;
-	// set only one of this and temperature, do not change both
-	TopP float64 `json:"top_p"`
-	// Stream 0: synchronous, 1: streaming (default, protocol: SSE).
-	// Synchronous requests time out after 60s; prefer streaming for long content
-	Stream int `json:"stream"`
-	// Messages holds the conversation, at most 40 entries, ordered oldest to
-	// newest; total input content is capped at 3000 tokens
-	Messages []TencentMessage `json:"messages"`
-}
-
-type TencentError struct {
-	Code    int    `json:"code"`
-	Message string `json:"message"`
-}
-
-type TencentUsage struct {
-	InputTokens  int `json:"input_tokens"`
-	OutputTokens int `json:"output_tokens"`
-	TotalTokens  int `json:"total_tokens"`
-}
-
-type TencentResponseChoices struct {
-	FinishReason string         `json:"finish_reason,omitempty"` // stream end flag, "stop" marks the final packet
-	Messages     TencentMessage `json:"messages,omitempty"`      // content in synchronous mode, null in stream mode; output capped at 1024 tokens
-	Delta        TencentMessage `json:"delta,omitempty"`         // content in stream mode, null in synchronous mode; output capped at 1024 tokens
-}
-
-type TencentChatResponse struct {
-	Choices []TencentResponseChoices `json:"choices,omitempty"` // results
-	Created string                   `json:"created,omitempty"` // unix timestamp as a string
-	Id      string                   `json:"id,omitempty"`      // conversation id
-	Usage   Usage                    `json:"usage,omitempty"`   // token counts
-	Error   TencentError             `json:"error,omitempty"`   // error info; note: this field may be null
-	Note    string                   `json:"note,omitempty"`    // note
-	ReqID   string                   `json:"req_id,omitempty"`  // unique request Id returned with every request, for reporting issues with the call
-}
-
-func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
+func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
 	messages := make([]TencentMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
 		message := request.Messages[i]
@@ -112,17 +57,17 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 	}
 }
 
-func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
-	fullTextResponse := OpenAITextResponse{
+func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
+	fullTextResponse := dto.OpenAITextResponse{
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Usage:   response.Usage,
 	}
 	if len(response.Choices) > 0 {
 		content, _ := json.Marshal(response.Choices[0].Messages.Content)
-		choice := OpenAITextResponseChoice{
+		choice := dto.OpenAITextResponseChoice{
 			Index: 0,
-			Message: Message{
+			Message: dto.Message{
 				Role:    "assistant",
 				Content: content,
 			},
@@ -133,24 +78,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 	return &fullTextResponse
 }
 
-func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
-	response := ChatCompletionsStreamResponse{
+func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
+	response := dto.ChatCompletionsStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
 		Model:   "tencent-hunyuan",
 	}
 	if len(TencentResponse.Choices) > 0 {
-		var choice ChatCompletionsStreamResponseChoice
+		var choice
dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = TencentResponse.Choices[0].Delta.Content if TencentResponse.Choices[0].FinishReason == "stop" { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &relaycommon.StopFinishReason } response.Choices = append(response.Choices, choice) } return &response } -func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { +func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { var responseText string scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -181,7 +126,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -209,28 +154,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } return nil, responseText } -func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var TencentResponse TencentChatResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &TencentResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if TencentResponse.Error.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: TencentResponse.Error.Message, Code: TencentResponse.Error.Code, }, @@ -240,7 +185,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus fullTextResponse := responseTencent2OpenAI(&TencentResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go new file mode 100644 index 0000000..9baebd7 --- /dev/null +++ b/relay/channel/xunfei/adaptor.go @@ -0,0 +1,68 @@ +package xunfei + +import ( + "errors" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +type Adaptor 
struct {
+	request *dto.GeneralOpenAIRequest
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	relaychannel.SetupApiRequestHeader(info, c, req)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	a.request = request
+	return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	// xunfei is served over WebSocket rather than HTTP, so there is no HTTP request to send here
+	dummyResp := &http.Response{}
+	dummyResp.StatusCode = http.StatusOK
+	return dummyResp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
+	splits := strings.Split(info.ApiKey, "|")
+	if len(splits) != 3 {
+		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+	}
+	if a.request == nil {
+		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+	}
+	if info.IsStream {
+		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+	} else {
+		err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}
diff --git a/relay/channel/xunfei/constants.go b/relay/channel/xunfei/constants.go
new file mode 100644
index 0000000..80bed08
--- /dev/null
+++ b/relay/channel/xunfei/constants.go
@@ -0,0 +1,11 @@
+package xunfei
+
+var ModelList = []string{
+	"SparkDesk",
+	"SparkDesk-v1.1",
+	"SparkDesk-v2.1",
+	"SparkDesk-v3.1",
+	"SparkDesk-v3.5",
+}
+
+var ChannelName = "xunfei"
diff --git a/relay/channel/xunfei/dto.go b/relay/channel/xunfei/dto.go
new file mode 100644
index 0000000..5556617
--- /dev/null
+++ b/relay/channel/xunfei/dto.go
@@ -0,0 +1,59 @@
+package xunfei
+
+import "one-api/dto"
+
+type XunfeiMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type XunfeiChatRequest struct {
+	Header struct {
+		AppId string `json:"app_id"`
+	} `json:"header"`
+	Parameter struct {
+		Chat struct {
+			Domain      string  `json:"domain,omitempty"`
+			Temperature float64 `json:"temperature,omitempty"`
+			TopK        int     `json:"top_k,omitempty"`
+			MaxTokens   uint    `json:"max_tokens,omitempty"`
+			Auditing    bool    `json:"auditing,omitempty"`
+		} `json:"chat"`
+	} `json:"parameter"`
+	Payload struct {
+		Message struct {
+			Text []XunfeiMessage `json:"text"`
+		} `json:"message"`
+	} `json:"payload"`
+}
+
+type XunfeiChatResponseTextItem struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+	Index   int    `json:"index"`
+}
+
+type XunfeiChatResponse struct {
+	Header struct {
+		Code    int    `json:"code"`
+		Message string `json:"message"`
+		Sid     string `json:"sid"`
+		Status  int    `json:"status"`
+	} `json:"header"`
+	Payload struct {
+		Choices struct {
+			Status int                          `json:"status"`
+			Seq    int                          `json:"seq"`
+			Text   []XunfeiChatResponseTextItem `json:"text"`
+		} `json:"choices"`
+		Usage struct {
+			//Text struct {
+			//	QuestionTokens string
`json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text dto.Usage `json:"text"` + } `json:"usage"` + } `json:"payload"` +} diff --git a/controller/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go similarity index 69% rename from controller/relay-xunfei.go rename to relay/channel/xunfei/relay-xunfei.go index c5191fd..e44b579 100644 --- a/controller/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -1,4 +1,4 @@ -package controller +package xunfei import ( "crypto/hmac" @@ -12,6 +12,9 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" "strings" "time" ) @@ -19,63 +22,7 @@ import ( // https://console.xfyun.cn/services/cbm // https://www.xfyun.cn/doc/spark/Web.html -type XunfeiMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type XunfeiChatRequest struct { - Header struct { - AppId string `json:"app_id"` - } `json:"header"` - Parameter struct { - Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` - } `json:"chat"` - } `json:"parameter"` - Payload struct { - Message struct { - Text []XunfeiMessage `json:"text"` - } `json:"message"` - } `json:"payload"` -} - -type XunfeiChatResponseTextItem struct { - Content string `json:"content"` - Role string `json:"role"` - Index int `json:"index"` -} - -type XunfeiChatResponse struct { - Header struct { - Code int `json:"code"` - Message string `json:"message"` - Sid string `json:"sid"` - Status int `json:"status"` - } `json:"header"` - Payload struct { - Choices struct { - Status int `json:"status"` - Seq int `json:"seq"` - Text []XunfeiChatResponseTextItem `json:"text"` - } `json:"choices"` - Usage struct { - //Text struct { - // QuestionTokens string `json:"question_tokens"` - // PromptTokens string `json:"prompt_tokens"` - // CompletionTokens string `json:"completion_tokens"` - // TotalTokens string `json:"total_tokens"` - //} `json:"text"` - Text Usage `json:"text"` - } `json:"usage"` - } `json:"payload"` -} - -func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { +func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { messages := make([]XunfeiMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -104,7 +51,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma return &xunfeiRequest } -func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { +func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse { if len(response.Payload.Choices.Text) == 0 { response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ { @@ -113,24 +60,24 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { } } content, _ := json.Marshal(response.Payload.Choices.Text[0].Content) - choice := OpenAITextResponseChoice{ + choice := dto.OpenAITextResponseChoice{ Index: 0, - Message: Message{ + Message: dto.Message{ Role: "assistant", Content: content, }, - FinishReason: stopFinishReason, + FinishReason: relaycommon.StopFinishReason, } - fullTextResponse := 
OpenAITextResponse{ + fullTextResponse := dto.OpenAITextResponse{ Object: "chat.completion", Created: common.GetTimestamp(), - Choices: []OpenAITextResponseChoice{choice}, + Choices: []dto.OpenAITextResponseChoice{choice}, Usage: response.Payload.Usage.Text, } return &fullTextResponse } -func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { +func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ { @@ -138,16 +85,16 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatComple }, } } - var choice ChatCompletionsStreamResponseChoice + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &stopFinishReason + choice.FinishReason = &relaycommon.StopFinishReason } - response := ChatCompletionsStreamResponse{ + response := dto.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "SparkDesk", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } @@ -178,14 +125,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { return callUrl } -func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { +func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - setEventStreamHeaders(c) - var usage Usage + service.SetEventStreamHeaders(c) + var usage dto.Usage c.Stream(func(w io.Writer) bool { select { case xunfeiResponse := <-dataChan: @@ -208,13 +155,13 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId return nil, &usage } -func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { +func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) if err != nil { - return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - var usage Usage + var usage dto.Usage var content string var xunfeiResponse XunfeiChatResponse stop := false @@ -237,14 +184,14 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin response := responseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", 
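// buildXunfeiAuthUrl's body is elided by the hunk above. Spark's documented
// WebSocket handshake signs the host, date and request line with HMAC-SHA256
// over the API secret and passes the base64-encoded result as query
// parameters; a sketch of that scheme, assuming the function follows the
// official docs (crypto/sha256, encoding/base64, net/url and time assumed
// imported alongside the crypto/hmac import already present):
func sparkAuthURLSketch(hostUrl, apiKey, apiSecret string) string {
	u, _ := url.Parse(hostUrl)
	date := time.Now().UTC().Format(http.TimeFormat)
	signText := fmt.Sprintf("host: %s\ndate: %s\nGET %s HTTP/1.1", u.Host, date, u.Path)
	mac := hmac.New(sha256.New, []byte(apiSecret))
	mac.Write([]byte(signText))
	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))
	origin := fmt.Sprintf(`api_key="%s", algorithm="hmac-sha256", headers="host date request-line", signature="%s"`,
		apiKey, signature)
	v := url.Values{}
	v.Set("authorization", base64.StdEncoding.EncodeToString([]byte(origin)))
	v.Set("date", date)
	v.Set("host", u.Host)
	return hostUrl + "?" + v.Encode()
}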
http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") _, _ = c.Writer.Write(jsonResponse) return nil, &usage } -func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { +func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go new file mode 100644 index 0000000..6fd3047 --- /dev/null +++ b/relay/channel/zhipu/adaptor.go @@ -0,0 +1,61 @@ +package zhipu + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + relaychannel "one-api/relay/channel" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + method := "invoke" + if info.IsStream { + method = "sse-invoke" + } + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + relaychannel.SetupApiRequestHeader(info, c, req) + token := getZhipuToken(info.ApiKey) + req.Header.Set("Authorization", token) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return requestOpenAI2Zhipu(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return relaychannel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = zhipuStreamHandler(c, resp) + } else { + err, usage = zhipuHandler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go new file mode 100644 index 0000000..81b18d6 --- /dev/null +++ b/relay/channel/zhipu/constants.go @@ -0,0 +1,7 @@ +package zhipu + +var ModelList = []string{ + "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", +} + +var ChannelName = "zhipu" diff --git a/relay/channel/zhipu/dto.go b/relay/channel/zhipu/dto.go new file mode 100644 index 0000000..1040124 --- /dev/null +++ b/relay/channel/zhipu/dto.go @@ -0,0 +1,46 @@ +package zhipu + +import ( + "one-api/dto" + "time" +) + +type ZhipuMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ZhipuRequest struct { + Prompt []ZhipuMessage `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ZhipuResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices 
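// The remainder of xunfeiMakeRequest is elided above; its shape is the usual
// websocket-to-channel pump: dial with the signed URL, write the request
// frame, then decode frames into dataChan until Spark marks the final frame
// (Status == 2, the same condition the stream converter checks). A sketch of
// the read loop using gorilla/websocket, which the dialer above comes from:
func pumpFrames(conn *websocket.Conn) (chan XunfeiChatResponse, chan bool) {
	dataChan := make(chan XunfeiChatResponse)
	stopChan := make(chan bool)
	go func() {
		for {
			var frame XunfeiChatResponse
			if err := conn.ReadJSON(&frame); err != nil {
				break // connection closed or malformed frame
			}
			dataChan <- frame
			if frame.Payload.Choices.Status == 2 {
				break // final frame of the answer
			}
		}
		stopChan <- true
	}()
	return dataChan, stopChan
}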
[]ZhipuMessage `json:"choices"` + dto.Usage `json:"usage"` +} + +type ZhipuResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ZhipuResponseData `json:"data"` +} + +type ZhipuStreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + dto.Usage `json:"usage"` +} + +type zhipuTokenData struct { + Token string + ExpiryTime time.Time +} diff --git a/controller/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go similarity index 67% rename from controller/relay-zhipu.go rename to relay/channel/zhipu/relay-zhipu.go index bc8d201..d6d82f1 100644 --- a/controller/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -1,4 +1,4 @@ -package controller +package zhipu import ( "bufio" @@ -8,6 +8,9 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" "strings" "sync" "time" @@ -18,46 +21,6 @@ import ( // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke -type ZhipuMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type ZhipuRequest struct { - Prompt []ZhipuMessage `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - RequestId string `json:"request_id,omitempty"` - Incremental bool `json:"incremental,omitempty"` -} - -type ZhipuResponseData struct { - TaskId string `json:"task_id"` - RequestId string `json:"request_id"` - TaskStatus string `json:"task_status"` - Choices []ZhipuMessage `json:"choices"` - Usage `json:"usage"` -} - -type ZhipuResponse struct { - Code int `json:"code"` - Msg string `json:"msg"` - Success bool `json:"success"` - Data ZhipuResponseData `json:"data"` -} - -type ZhipuStreamMetaResponse struct { - RequestId string `json:"request_id"` - TaskId string `json:"task_id"` - TaskStatus string `json:"task_status"` - Usage `json:"usage"` -} - -type zhipuTokenData struct { - Token string - ExpiryTime time.Time -} - var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 @@ -108,7 +71,7 @@ func getZhipuToken(apikey string) string { return tokenString } -func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { +func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest { messages := make([]ZhipuMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -135,19 +98,19 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } } -func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { - fullTextResponse := OpenAITextResponse{ +func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse { + fullTextResponse := dto.OpenAITextResponse{ Id: response.Data.TaskId, Object: "chat.completion", Created: common.GetTimestamp(), - Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), + Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)), Usage: response.Data.Usage, } for i, choice := range response.Data.Choices { content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) - openaiChoice := OpenAITextResponseChoice{ + openaiChoice := dto.OpenAITextResponseChoice{ Index: i, - Message: Message{ + Message: dto.Message{ Role: choice.Role, Content: content, }, @@ -161,34 +124,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { 
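// getZhipuToken's body is elided above, but zhipuTokens and expSeconds make
// the caching scheme clear: one signed token per API key, recomputed only
// once it expires. The caching skeleton, with signToken standing in for the
// actual JWT-style token construction (time assumed imported):
func cachedZhipuToken(apikey string, signToken func(string) string) string {
	if v, ok := zhipuTokens.Load(apikey); ok {
		data := v.(zhipuTokenData)
		if time.Now().Before(data.ExpiryTime) {
			return data.Token // cached token still valid
		}
	}
	token := signToken(apikey)
	zhipuTokens.Store(apikey, zhipuTokenData{
		Token:      token,
		ExpiryTime: time.Now().Add(time.Duration(expSeconds) * time.Second),
	})
	return token
}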
return &fullTextResponse } -func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { - var choice ChatCompletionsStreamResponseChoice +func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = zhipuResponse - response := ChatCompletionsStreamResponse{ + response := dto.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } -func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { - var choice ChatCompletionsStreamResponseChoice +func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) { + var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = "" - choice.FinishReason = &stopFinishReason - response := ChatCompletionsStreamResponse{ + choice.FinishReason = &relaycommon.StopFinishReason + response := dto.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId, Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "chatglm", - Choices: []ChatCompletionsStreamResponseChoice{choice}, + Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response, &zhipuResponse.Usage } -func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var usage *Usage +func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -225,7 +188,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt } stopChan <- true }() - setEventStreamHeaders(c) + service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: @@ -260,28 +223,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt }) err := resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, usage } -func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { +func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var zhipuResponse ZhipuResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if 
!zhipuResponse.Success { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: dto.OpenAIError{ Message: zhipuResponse.Msg, Type: "zhipu_error", Param: "", @@ -293,7 +256,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go new file mode 100644 index 0000000..62302a0 --- /dev/null +++ b/relay/common/relay_info.go @@ -0,0 +1,71 @@ +package common + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + "one-api/relay/constant" + "strings" + "time" +) + +type RelayInfo struct { + ChannelType int + ChannelId int + TokenId int + UserId int + Group string + TokenUnlimited bool + StartTime time.Time + ApiType int + IsStream bool + RelayMode int + UpstreamModelName string + RequestURLPath string + ApiVersion string + PromptTokens int + ApiKey string + BaseUrl string +} + +func GenRelayInfo(c *gin.Context) *RelayInfo { + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + tokenUnlimited := c.GetBool("token_unlimited_quota") + startTime := time.Now() + + apiType := constant.ChannelType2APIType(channelType) + + info := &RelayInfo{ + RelayMode: constant.Path2RelayMode(c.Request.URL.Path), + BaseUrl: c.GetString("base_url"), + RequestURLPath: c.Request.URL.String(), + ChannelType: channelType, + ChannelId: channelId, + TokenId: tokenId, + UserId: userId, + Group: group, + TokenUnlimited: tokenUnlimited, + StartTime: startTime, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + } + if info.BaseUrl == "" { + info.BaseUrl = common.ChannelBaseURLs[channelType] + } + //if info.ChannelType == common.ChannelTypeAzure { + // info.ApiVersion = GetAzureAPIVersion(c) + //} + return info +} + +func (info *RelayInfo) SetPromptTokens(promptTokens int) { + info.PromptTokens = promptTokens +} + +func (info *RelayInfo) SetIsStream(isStream bool) { + info.IsStream = isStream +} diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go new file mode 100644 index 0000000..0ab38cb --- /dev/null +++ b/relay/common/relay_utils.go @@ -0,0 +1,68 @@ +package common + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "strconv" + "strings" +) + +var StopFinishReason = "stop" + +func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: dto.OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { 
+ return + } + var textResponse dto.TextResponse + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return + } + openAIErrorWithStatusCode.OpenAIError = textResponse.Error + return +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go new file mode 100644 index 0000000..485f8e0 --- /dev/null +++ b/relay/constant/api_type.go @@ -0,0 +1,45 @@ +package constant + +import ( + "one-api/common" +) + +const ( + APITypeOpenAI = iota + APITypeAnthropic + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei + APITypeAIProxyLibrary + APITypeTencent + APITypeGemini + + APITypeDummy // this one is only for count, do not add any channel after this +) + +func ChannelType2APIType(channelType int) int { + apiType := APITypeOpenAI + switch channelType { + case common.ChannelTypeAnthropic: + apiType = APITypeAnthropic + case common.ChannelTypeBaidu: + apiType = APITypeBaidu + case common.ChannelTypePaLM: + apiType = APITypePaLM + case common.ChannelTypeZhipu: + apiType = APITypeZhipu + case common.ChannelTypeAli: + apiType = APITypeAli + case common.ChannelTypeXunfei: + apiType = APITypeXunfei + case common.ChannelTypeAIProxyLibrary: + apiType = APITypeAIProxyLibrary + case common.ChannelTypeTencent: + apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini + } + return apiType +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go new file mode 100644 index 0000000..beea7dc --- /dev/null +++ b/relay/constant/relay_mode.go @@ -0,0 +1,50 @@ +package constant + +import "strings" + +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings + RelayModeModerations + RelayModeImagesGenerations + RelayModeEdits + RelayModeMidjourneyImagine + RelayModeMidjourneyDescribe + RelayModeMidjourneyBlend + RelayModeMidjourneyChange + RelayModeMidjourneySimpleChange + RelayModeMidjourneyNotify + RelayModeMidjourneyTaskFetch + RelayModeMidjourneyTaskFetchByCondition + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation +) + +func Path2RelayMode(path string) int { + relayMode := RelayModeUnknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = RelayModeChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = RelayModeCompletions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = RelayModeModerations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = 
RelayModeEdits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = RelayModeAudioTranslation + } + return relayMode +} diff --git a/controller/relay-audio.go b/relay/relay-audio.go similarity index 67% rename from controller/relay-audio.go rename to relay/relay-audio.go index 63f8563..1a62fff 100644 --- a/controller/relay-audio.go +++ b/relay/relay-audio.go @@ -1,4 +1,4 @@ -package controller +package relay import ( "bytes" @@ -10,7 +10,10 @@ import ( "io" "net/http" "one-api/common" + "one-api/controller" + "one-api/dto" "one-api/model" + "one-api/service" "strings" "time" ) @@ -24,7 +27,7 @@ var availableVoices = []string{ "shimmer", } -func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayAudioHelper(c *gin.Context, relayMode int) *controller.OpenAIErrorWithStatusCode { tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") @@ -36,7 +39,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err := common.UnmarshalBodyReusable(c, &audioRequest) if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } } else { audioRequest = AudioRequest{ @@ -47,15 +50,15 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // request validation if audioRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) } if strings.HasPrefix(audioRequest.Model, "tts-1") { if audioRequest.Voice == "" { - return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) } if !common.StringsContains(availableVoices, audioRequest.Voice) { - return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) + return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) } } @@ -66,14 +69,14 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(userId) if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, 
"decrease_user_quota_failed", http.StatusInternalServerError) } if userQuota > 100*preConsumedQuota { // in this case, we do not pre-consume quota @@ -83,7 +86,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if preConsumedQuota > 0 { userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -93,7 +96,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[audioRequest.Model] != "" { audioRequest.Model = modelMap[audioRequest.Model] @@ -106,10 +109,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL := common.getFullRequestURL(baseURL, requestURL, channelType) if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) + apiVersion := common.GetAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) } @@ -117,7 +120,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { @@ -133,25 +136,25 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := controller.httpClient.Do(req) if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } err = req.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } err = c.Request.Body.Close() if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } if resp.StatusCode != http.StatusOK { - return relayErrorHandler(resp) + return common.relayErrorHandler(resp) } - var audioResponse AudioResponse + var audioResponse dto.AudioResponse defer func(ctx context.Context) { go func() { @@ -159,10 +162,10 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := 0 var promptTokens = 0 if 
diff --git a/controller/relay-image.go b/relay/relay-image.go similarity index 96% rename from controller/relay-image.go rename to relay/relay-image.go index 39f3308..e717c3e 100644 --- a/controller/relay-image.go +++ b/relay/relay-image.go @@ -1,4 +1,4 @@ -package controller +package relay import ( "bytes" @@ -10,12 +10,15 @@ import ( "io" "net/http" "one-api/common" + "one-api/dto" "one-api/model" + relaycommon "one-api/relay/common" + "one-api/service" "strings" "time" ) -func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { +func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") @@ -24,7 +27,7 @@ group := c.GetString("group") startTime := time.Now() - var imageRequest ImageRequest + var imageRequest dto.ImageRequest if consumeQuota { err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil { @@ -90,7 +93,7 @@ - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := GetAPIVersion(c) + apiVersion := relaycommon.GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) } @@ -151,7 +154,7 @@ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) }
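Both helpers lean on GetFullRequestURL from relay/common (added earlier in this patch). A small illustration of its two behaviours — the plain concatenation and the Cloudflare-gateway prefix stripping — assuming this module's ChannelType constants and a hypothetical gateway base URL:

package main

import (
	"fmt"

	"one-api/common"
	relaycommon "one-api/relay/common"
)

func main() {
	// Ordinary channel: base URL and request path are simply concatenated.
	fmt.Println(relaycommon.GetFullRequestURL(
		"https://api.openai.com", "/v1/images/generations", common.ChannelTypeOpenAI))

	// Behind the Cloudflare AI gateway, the "/v1" prefix is stripped for OpenAI channels.
	fmt.Println(relaycommon.GetFullRequestURL(
		"https://gateway.ai.cloudflare.com/v1/acct/gw/openai", // hypothetical gateway URL
		"/v1/images/generations", common.ChannelTypeOpenAI))
}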
diff --git a/controller/relay-mj.go b/relay/relay-mj.go similarity index 97% rename from controller/relay-mj.go rename to relay/relay-mj.go index e5b377a..36114bf 100644 --- a/controller/relay-mj.go +++ b/relay/relay-mj.go @@ -1,4 +1,4 @@ -package controller +package relay import ( "bytes" @@ -9,6 +9,7 @@ import ( "log" "net/http" "one-api/common" "one-api/model" + "one-api/service" "strconv" "strings" @@ -104,7 +105,7 @@ func RelayMidjourneyImage(c *gin.Context) { return } -func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { +func RelayMidjourneyNotify(c *gin.Context) *MidjourneyResponse { var midjRequest Midjourney err := common.UnmarshalBodyReusable(c, &midjRequest) if err != nil { @@ -167,7 +168,7 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo return } -func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { +func RelayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse { userId := c.GetInt("id") var err error var respBody []byte @@ -244,7 +245,7 @@ const ( MJSubmitActionUpscale = "UPSCALE" // upscale ) -func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { +func RelayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { imageModel := "midjourney" tokenId := c.GetInt("token_id") @@ -427,21 +428,21 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { Description: "create_request_failed", } } - //req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + //req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) //mjToken := "" - //if c.Request.Header.Get("Authorization") != "" { - // mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1] + //if c.Request.Header.Get("ApiKey") != "" { + // mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1] //} - //req.Header.Set("Authorization", "Bearer midjourney-proxy") + //req.Header.Set("ApiKey", "Bearer midjourney-proxy") req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1]) // print request header log.Printf("request header: %s", req.Header) log.Printf("request body: %s", midjRequest.Prompt) - resp, err := httpClient.Do(req) + resp, err := service.GetHttpClient().Do(req) if err != nil { return &MidjourneyResponse{ Code: 4,
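One caveat in the submit path above: strings.Split(header, " ")[1] panics whenever the Authorization header is absent or has no space. A defensive variant — a hypothetical helper, not part of this patch's API — could look like:

package main

import (
	"fmt"
	"strings"
)

// secretFromAuthHeader degrades to "" instead of panicking when the
// Authorization header is empty or has no "Bearer "-style prefix.
func secretFromAuthHeader(authorization string) string {
	parts := strings.SplitN(authorization, " ", 2)
	if len(parts) < 2 {
		return ""
	}
	return parts[1]
}

func main() {
	fmt.Printf("%q\n", secretFromAuthHeader("Bearer sk-123")) // "sk-123"
	fmt.Printf("%q\n", secretFromAuthHeader(""))              // "" rather than a panic
}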
"one-api/service" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) { + textRequest := &dto.GeneralOpenAIRequest{} + err := common.UnmarshalBodyReusable(c, textRequest) + if err != nil { + return nil, err + } + if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + + if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + return nil, errors.New("max_tokens is invalid") + } + if textRequest.Model == "" { + return nil, errors.New("model is required") + } + switch relayInfo.RelayMode { + case relayconstant.RelayModeCompletions: + if textRequest.Prompt == "" { + return nil, errors.New("field prompt is required") + } + case relayconstant.RelayModeChatCompletions: + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return nil, errors.New("field messages is required") + } + case relayconstant.RelayModeEmbeddings: + case relayconstant.RelayModeModerations: + if textRequest.Input == "" { + return nil, errors.New("field input is required") + } + case relayconstant.RelayModeEdits: + if textRequest.Instruction == "" { + return nil, errors.New("field instruction is required") + } + } + relayInfo.IsStream = textRequest.Stream + return textRequest, nil +} + +func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { + + relayInfo := relaycommon.GenRelayInfo(c) + + // get & validate textRequest 获取并验证文本请求 + textRequest, err := getAndValidateTextRequest(c, relayInfo) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + } + + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[textRequest.Model] != "" { + textRequest.Model = modelMap[textRequest.Model] + isModelMapped = true + } + } + modelPrice := common.GetModelPrice(textRequest.Model, false) + groupRatio := common.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + promptTokens, err := getPromptTokens(textRequest, relayInfo) + + // count messages token error 计算promptTokens错误 + if err != nil { + return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + } + + if modelPrice == -1 { + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + int(textRequest.MaxTokens) + } + modelRatio = common.GetModelRatio(textRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + + // pre-consume quota 预消耗配额 + userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if err != nil { + return openaiErr + } + + adaptor := relaychannel.GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return 
+ +func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { + + relayInfo := relaycommon.GenRelayInfo(c) + + // get & validate the text request + textRequest, err := getAndValidateTextRequest(c, relayInfo) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_text_request", http.StatusBadRequest) + } + + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[textRequest.Model] != "" { + textRequest.Model = modelMap[textRequest.Model] + isModelMapped = true + } + } + modelPrice := common.GetModelPrice(textRequest.Model, false) + groupRatio := common.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + promptTokens, err := getPromptTokens(textRequest, relayInfo) + + // counting prompt tokens failed + if err != nil { + return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + } + + if modelPrice == -1 { + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + int(textRequest.MaxTokens) + } + modelRatio = common.GetModelRatio(textRequest.Model) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + + // pre-consume quota + userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + + adaptor := relaychannel.GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo, *textRequest) + var requestBody io.Reader + if relayInfo.ApiType == relayconstant.APITypeOpenAI { + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + } else { + convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest) + if err != nil { + return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) + } + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonData) + } + + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + return openaiErr + } + + postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) + return nil +} + +func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) { + var promptTokens int + var err error + + switch info.RelayMode { + case relayconstant.RelayModeChatCompletions: + promptTokens, err = service.CountTokenMessages(textRequest.Messages, textRequest.Model) + case relayconstant.RelayModeCompletions: + promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil + case relayconstant.RelayModeModerations: + promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil + default: + err = errors.New("unknown relay mode") + promptTokens = 0 + } + info.PromptTokens = promptTokens + return promptTokens, err +} + +// preConsumeQuota pre-deducts quota and returns the user's remaining quota +func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *dto.OpenAIErrorWithStatusCode) { + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + if err != nil { + return 0, service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + if userQuota < 0 || userQuota-preConsumedQuota < 0 { + return 0, service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) + if err != nil { + return 0, service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // the user has plenty of quota; check whether the token does too + if !relayInfo.TokenUnlimited { + // limited token: only trust it if its own quota is also sufficient + tokenQuota := c.GetInt("token_quota") + if tokenQuota > 100*preConsumedQuota { + // token quota is sufficient: trust the token and skip pre-consumption + preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume",
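The 100x rule in preConsumeQuota is the subtle part; a toy reimplementation of just that decision, with made-up quota numbers:

package main

import "fmt"

// trustedToSkipPreConsume mirrors the decision in preConsumeQuota above:
// hold nothing up front when the user (and, for limited tokens, the token
// itself) has more than 100x the estimated cost.
func trustedToSkipPreConsume(userQuota, tokenQuota, preConsumed int, tokenUnlimited bool) bool {
	if userQuota <= 100*preConsumed {
		return false
	}
	if tokenUnlimited {
		return true
	}
	return tokenQuota > 100*preConsumed
}

func main() {
	fmt.Println(trustedToSkipPreConsume(1_000_000, 0, 500, true))       // true: unlimited token, rich user
	fmt.Println(trustedToSkipPreConsume(1_000_000, 20_000, 500, false)) // false: token quota too small
	fmt.Println(trustedToSkipPreConsume(40_000, 900_000, 500, false))   // false: user quota too small
}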
relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + } + } else { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + } + } + if preConsumedQuota > 0 { + userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) + if err != nil { + return 0, service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + return userQuota, nil +} + +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) { + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + + tokenName := ctx.GetString("token_name") + + quota := 0 + if modelPrice == -1 { + completionRatio := common.GetCompletionRatio(textRequest.Model) + quota = promptTokens + int(float64(completionTokens)*completionRatio) + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + totalTokens := promptTokens + completionTokens + var logContent string + if modelPrice == -1 { + logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + } else { + logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) + } + + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += "(可能是上游超时)" + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) + } else { + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + err = model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + common.LogError(ctx, "error updating user quota cache: "+err.Error()) + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + logModel := textRequest.Model + if strings.HasPrefix(logModel, "gpt-4-gizmo") { + logModel = "gpt-4-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", textRequest.Model) + } + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream) +} diff --git a/router/relay-router.go b/router/relay-router.go index fd80b30..6a30a5a 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,10 +1,10 @@ package router import ( + "github.com/gin-gonic/gin" "one-api/controller" "one-api/middleware" - - "github.com/gin-gonic/gin" + "one-api/relay" ) func SetRelayRouter(router *gin.Engine) { @@ -44,7 +44,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/moderations", controller.Relay) } relayMjRouter := router.Group("/mj") - relayMjRouter.GET("/image/:id", controller.RelayMidjourneyImage) + relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
diff --git a/service/channel.go b/service/channel.go new file mode 100644 index 0000000..b9a7627 --- /dev/null +++ b/service/channel.go @@ -0,0 +1,53 @@ +package service + +import ( + "fmt" + "net/http" + "one-api/common" + relaymodel "one-api/dto" + "one-api/model" +) + +// DisableChannel disables the channel and notifies the root user +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +// EnableChannel enables the channel and notifies the root user +func EnableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} + +func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { + if !common.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool { + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} diff --git a/service/error.go b/service/error.go new file mode 100644 index 0000000..89d200c --- /dev/null +++ b/service/error.go @@ -0,0 +1,29 @@ +package service + +import ( + "fmt" + "one-api/common" + "one-api/dto" + "strings" +) + +// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode +func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { + text := err.Error() + // mask upstream connection failures so internal addresses are not exposed + if strings.Contains(text, "Post") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } + // avoid exposing internal errors + + openAIError := dto.OpenAIError{ + Message: text, + Type: "new_api_error", + Code: code, + } + return &dto.OpenAIErrorWithStatusCode{ + OpenAIError: openAIError, + StatusCode: statusCode, + } +}
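How these two services compose in a caller, sketched under the assumption that the module builds and the behaviour of ShouldDisableChannel also depends on the AutomaticDisableChannelEnabled switch:

package main

import (
	"errors"
	"fmt"
	"net/http"

	"one-api/service"
)

func main() {
	// Wrap an upstream failure the way the relay handlers above do.
	wrapped := service.OpenAIErrorWrapper(errors.New("connection refused"),
		"do_request_failed", http.StatusInternalServerError)
	fmt.Println(wrapped.StatusCode, wrapped.OpenAIError.Code)

	// A 401 from the upstream is one of the auto-disable conditions
	// (returns false when AutomaticDisableChannelEnabled is off).
	fmt.Println(service.ShouldDisableChannel(&wrapped.OpenAIError, http.StatusUnauthorized))
}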
c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/controller/relay-utils.go b/service/token_counter.go similarity index 61% rename from controller/relay-utils.go rename to service/token_counter.go index 2efc1de..179eccd 100644 --- a/controller/relay-utils.go +++ b/service/token_counter.go @@ -1,27 +1,19 @@ -package controller +package service import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" "image" - _ "image/gif" - _ "image/jpeg" - _ "image/png" - "io" "log" "math" - "net/http" "one-api/common" - "strconv" + "one-api/dto" "strings" "unicode/utf8" ) -var stopFinishReason = "stop" - // tokenEncoderMap won't grow after initialization var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var defaultTokenEncoder *tiktoken.Tiktoken @@ -70,7 +62,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func getImageToken(imageUrl *MessageImageUrl) (int, error) { +func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { if imageUrl.Detail == "low" { return 85, nil } @@ -124,7 +116,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) { return tiles*170 + 85, nil } -func countTokenMessages(messages []Message, model string) (int, error) { +func CountTokenMessages(messages []dto.Message, model string) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: @@ -146,7 +138,7 @@ func countTokenMessages(messages []Message, model string) (int, error) { tokenNum += tokensPerMessage tokenNum += getTokenNum(tokenEncoder, message.Role) if len(message.Content) > 0 { - var arrayContent []MediaMessage + var arrayContent []dto.MediaMessage if err := json.Unmarshal(message.Content, &arrayContent); err != nil { var stringContent string if err := json.Unmarshal(message.Content, &stringContent); err != nil { @@ -163,7 +155,7 @@ func countTokenMessages(messages []Message, model string) (int, error) { if m.Type == "image_url" { var imageTokenNum int if str, ok := m.ImageUrl.(string); ok { - imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"}) + imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"}) } else { imageUrlMap := m.ImageUrl.(map[string]interface{}) detail, ok := imageUrlMap["detail"] @@ -172,7 +164,7 @@ func countTokenMessages(messages []Message, model string) (int, error) { } else { imageUrlMap["detail"] = "auto" } - imageUrl := MessageImageUrl{ + imageUrl := dto.MessageImageUrl{ Url: imageUrlMap["url"].(string), Detail: imageUrlMap["detail"].(string), } @@ -195,16 +187,16 @@ func countTokenMessages(messages []Message, model string) (int, error) { return tokenNum, nil } -func countTokenInput(input any, model string) int { +func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: - return countTokenText(v, model) + return CountTokenText(v, model) case []string: text := "" for _, s := range v { text += s } - return countTokenText(text, model) + return CountTokenText(text, model) } return 0 } @@ -213,118 +205,11 @@ func countAudioToken(text string, model string) int { if strings.HasPrefix(model, "tts") { return utf8.RuneCountInString(text) } else { - return countTokenText(text, model) + return CountTokenText(text, model) } } -func countTokenText(text string, model string) int { +func CountTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return 
getTokenNum(tokenEncoder, text) } - -func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { - text := err.Error() - // 定义一个正则表达式匹配URL - if strings.Contains(text, "Post") { - common.SysLog(fmt.Sprintf("error: %s", text)) - text = "请求上游地址失败" - } - //避免暴露内部错误 - - openAIError := OpenAIError{ - Message: text, - Type: "new_api_error", - Code: code, - } - return &OpenAIErrorWithStatusCode{ - OpenAIError: openAIError, - StatusCode: statusCode, - } -} - -func shouldDisableChannel(err *OpenAIError, statusCode int) bool { - if !common.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { - return true - } - return false -} - -func shouldEnableChannel(err error, openAIErr *OpenAIError) bool { - if !common.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true -} - -func setEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") -} - -func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { - openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return - } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error - return -} - -func getFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - return fullRequestURL -} - -func GetAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - return apiVersion -}
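getImageToken above follows OpenAI's usual vision accounting; a self-contained reimplementation of that tile math — a sketch following the 85/170-token rule, not a copy of the function's elided body:

package main

import (
	"fmt"
	"math"
)

// openAIVisionTokens: "low" detail is a flat 85 tokens; otherwise the image
// is scaled (long side <= 2048, short side <= 768) and billed at 170 tokens
// per 512px tile plus the 85-token base, matching "tiles*170 + 85" above.
func openAIVisionTokens(width, height int, detail string) int {
	if detail == "low" {
		return 85
	}
	w, h := float64(width), float64(height)
	if longSide := math.Max(w, h); longSide > 2048 {
		scale := 2048 / longSide
		w, h = w*scale, h*scale
	}
	if shortSide := math.Min(w, h); shortSide > 768 {
		scale := 768 / shortSide
		w, h = w*scale, h*scale
	}
	tiles := int(math.Ceil(w/512)) * int(math.Ceil(h/512))
	return tiles*170 + 85
}

func main() {
	fmt.Println(openAIVisionTokens(1024, 1024, "high")) // 4 tiles -> 765
	fmt.Println(openAIVisionTokens(1024, 1024, "low"))  // 85
}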
diff --git a/service/usage_helpr.go b/service/usage_helpr.go new file mode 100644 index 0000000..c1fcfb5 --- /dev/null +++ b/service/usage_helpr.go @@ -0,0 +1,27 @@ +package service + +import ( + "errors" + "one-api/dto" + "one-api/relay/constant" +) + +func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) { + switch relayMode { + case constant.RelayModeChatCompletions: + return CountTokenMessages(textRequest.Messages, textRequest.Model) + case constant.RelayModeCompletions: + return CountTokenInput(textRequest.Prompt, textRequest.Model), nil + case constant.RelayModeModerations: + return CountTokenInput(textRequest.Input, textRequest.Model), nil + } + return 0, errors.New("unknown relay mode") +} + +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage { + usage := &dto.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = CountTokenText(responseText, modelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} diff --git a/service/user_notify.go b/service/user_notify.go new file mode 100644 index 0000000..7ae9062 --- /dev/null +++ b/service/user_notify.go @@ -0,0 +1,17 @@ +package service + +import ( + "fmt" + "one-api/common" + "one-api/model" +) + +func notifyRootUser(subject string, content string) { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + err := common.SendEmail(subject, common.RootUserEmail, content) + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +}
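ResponseText2Usage is the fallback for streamed responses that carry no usage block: the prompt side is already known, and the completion side is re-counted from the concatenated text. Typical use, assuming the code builds inside this module:

package main

import (
	"fmt"

	"one-api/service"
)

func main() {
	// 12 prompt tokens were counted before the request; the completion
	// tokens are rebuilt from the streamed text after the fact.
	usage := service.ResponseText2Usage("Hello there, how may I help?", "gpt-3.5-turbo", 12)
	fmt.Println(usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens)
}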