From 8a730cfe12cb9c84abcfa95f08c166618d436ab7 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 6 Jul 2024 18:42:48 +0800 Subject: [PATCH] feat: support jina rerank --- README.md | 4 +- common/constants.go | 3 ++ relay/channel/jina/adaptor.go | 64 ++++++++++++++++++++++++++ relay/channel/jina/constant.go | 8 ++++ relay/channel/jina/relay-jina.go | 35 ++++++++++++++ relay/constant/api_type.go | 3 ++ relay/relay_adaptor.go | 3 ++ web/src/constants/channel.constants.js | 3 +- 8 files changed, 120 insertions(+), 3 deletions(-) create mode 100644 relay/channel/jina/adaptor.go create mode 100644 relay/channel/jina/constant.go create mode 100644 relay/channel/jina/relay-jina.go diff --git a/README.md b/README.md index 5a55362..fff4f20 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ 3. 选择你的bot,然后输入http(s)://你的网站地址/login 4. Telegram Bot 名称是bot username 去掉@后的字符串 13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md) -14. 支持Rerank模型,目前仅兼容Cohere的rerank,可接入Dify,[对接文档](Rerank.md) +14. 支持Rerank模型,目前仅兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md) ## 模型支持 此版本额外支持以下模型: @@ -46,7 +46,7 @@ 6. [零一万物](https://platform.lingyiwanwu.com/) 7. 自定义渠道,支持填入完整调用地址 8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) -9. Rerank模型,目前仅支持[Cohere](https://cohere.ai/) +9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md) 10. Dify 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 diff --git a/common/constants.go b/common/constants.go index 386f37d..5a53283 100644 --- a/common/constants.go +++ b/common/constants.go @@ -211,6 +211,7 @@ const ( ChannelTypeMiniMax = 35 ChannelTypeSunoAPI = 36 ChannelTypeDify = 37 + ChannelTypeJina = 38 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -254,4 +255,6 @@ var ChannelBaseURLs = []string{ "https://api.cohere.ai", //34 "https://api.minimax.chat", //35 "", //36 + "", //37 + "https://api.jina.ai", //38 } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go new file mode 100644 index 0000000..48616b6 --- /dev/null +++ b/relay/channel/jina/adaptor.go @@ -0,0 +1,64 @@ +package jina + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayMode == constant.RelayModeRerank { + return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + } else if info.RelayMode == constant.RelayModeEmbeddings { + return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil + } + return "", errors.New("invalid relay mode") +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.RelayMode == constant.RelayModeRerank { + err, usage = jinaRerankHandler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/jina/constant.go b/relay/channel/jina/constant.go new file mode 100644 index 0000000..45fc44c --- /dev/null +++ b/relay/channel/jina/constant.go @@ -0,0 +1,8 @@ +package jina + +var ModelList = []string{ + "jina-clip-v1", + "jina-reranker-v2-base-multilingual", +} + +var ChannelName = "jina" diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go new file mode 100644 index 0000000..5fdd44f --- /dev/null +++ b/relay/channel/jina/relay-jina.go @@ -0,0 +1,35 @@ +package jina + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/service" +) + +func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var jinaResp dto.RerankResponse + err = json.Unmarshal(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + jsonResponse, err := json.Marshal(jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &jinaResp.Usage +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 15ba541..0ce2657 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -21,6 +21,7 @@ const ( APITypeAws APITypeCohere APITypeDify + APITypeJina APITypeDummy // this one is only for count, do not add any channel after this ) @@ -60,6 +61,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeCohere case common.ChannelTypeDify: apiType = APITypeDify + case common.ChannelTypeJina: + apiType = APITypeJina } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 36edef3..8998540 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/cohere" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" + "one-api/relay/channel/jina" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" "one-api/relay/channel/palm" @@ -56,6 +57,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &cohere.Adaptor{} case constant.APITypeDify: return &dify.Adaptor{} + case constant.APITypeJina: + return &jina.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 9383f21..ff1d281 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -104,7 +104,8 @@ export const CHANNEL_OPTIONS = [ { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, { key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' }, { key: 35, text: 'MiniMax', value: 35, color: 'green', label: 'MiniMax' }, - { key: 37, text: 'Dify', value: 37, color: 'green', label: 'Dify' }, + { key: 37, text: 'Dify', value: 37, color: 'teal', label: 'Dify' }, + { key: 38, text: 'Jina', value: 38, color: 'blue', label: 'Jina' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' }, { key: 22,