diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 6a04d08..f296ed0 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil } return "", errors.New("invalid relay mode") } @@ -58,6 +58,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) + } else if info.RelayMode == constant.RelayModeEmbeddings { + err, usage = jinaEmbeddingHandler(c, resp) } return } diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go index 5fdd44f..6c339ae 100644 --- a/relay/channel/jina/relay-jina.go +++ b/relay/channel/jina/relay-jina.go @@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit _, err = c.Writer.Write(jsonResponse) return nil, &jinaResp.Usage } + +func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var jinaResp dto.OpenAIEmbeddingResponse + err = json.Unmarshal(responseBody, &jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + jsonResponse, err := json.Marshal(jinaResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &jinaResp.Usage +} diff --git a/relay/relay-text.go b/relay/relay-text.go index 3c5393a..14e82f1 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) } case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeModerations: - if textRequest.Input == "" { + if textRequest.Input == "" || textRequest.Input == nil { return nil, errors.New("field input is required") } case relayconstant.RelayModeEdits: