From 8eedad9470ee723f7cd8fe8229f77c39f934b8cb Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Tue, 26 Mar 2024 19:53:53 +0800
Subject: [PATCH] feat: support ollama embedding

---
 relay/channel/ollama/adaptor.go      | 21 ++++++-
 relay/channel/ollama/dto.go          | 28 ++++++---
 relay/channel/ollama/relay-ollama.go | 94 +++++++++++++++++++++++++---
 3 files changed, 120 insertions(+), 23 deletions(-)

diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 4c17252..f66d9a9 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 )
 
@@ -19,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return info.BaseUrl + "/api/embeddings", nil
+	default:
+		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -31,7 +37,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(*request), nil
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return requestOpenAI2Embeddings(*request), nil
+	default:
+		return requestOpenAI2Ollama(*request), nil
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -44,7 +55,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		if info.RelayMode == relayconstant.RelayModeEmbeddings {
+			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		} else {
+			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		}
 	}
 	return
 }
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index a43fb84..a6d6238 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -3,16 +3,24 @@ package ollama
 import "one-api/dto"
 
 type OllamaRequest struct {
-	Model    string         `json:"model,omitempty"`
-	Messages []dto.Message  `json:"messages,omitempty"`
-	Stream   bool           `json:"stream,omitempty"`
-	Options  *OllamaOptions `json:"options,omitempty"`
+	Model       string        `json:"model,omitempty"`
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+	Seed        float64       `json:"seed,omitempty"`
+	Topp        float64       `json:"top_p,omitempty"`
+	TopK        int           `json:"top_k,omitempty"`
+	Stop        any           `json:"stop,omitempty"`
 }
 
-type OllamaOptions struct {
-	Temperature float64 `json:"temperature,omitempty"`
-	Seed        float64 `json:"seed,omitempty"`
-	Topp        float64 `json:"top_p,omitempty"`
-	TopK        int     `json:"top_k,omitempty"`
-	Stop        any     `json:"stop,omitempty"`
+type OllamaEmbeddingRequest struct {
+	Model  string `json:"model,omitempty"`
+	Prompt any    `json:"prompt,omitempty"`
 }
+
+type OllamaEmbeddingResponse struct {
+	Embedding []float64 `json:"embedding,omitempty"`
+}
+
+//type OllamaOptions struct {
+//}
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 41a2a15..fa5f818 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -1,7 +1,14 @@
 package ollama
 
 import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
 	"one-api/dto"
+	"one-api/service"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -20,15 +27,82 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 		Stop, _ = request.Stop.([]string)
 	}
 	return &OllamaRequest{
-		Model:    request.Model,
-		Messages: messages,
-		Stream:   request.Stream,
-		Options: &OllamaOptions{
-			Temperature: request.Temperature,
-			Seed:        request.Seed,
-			Topp:        request.TopP,
-			TopK:        request.TopK,
-			Stop:        Stop,
-		},
+		Model:       request.Model,
+		Messages:    messages,
+		Stream:      request.Stream,
+		Temperature: request.Temperature,
+		Seed:        request.Seed,
+		Topp:        request.TopP,
+		TopK:        request.TopK,
+		Stop:        Stop,
 	}
 }
+
+func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+	return &OllamaEmbeddingRequest{
+		Model:  request.Model,
+		Prompt: request.Input,
+	}
+}
+
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var ollamaEmbeddingResponse OllamaEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
+	data = append(data, dto.OpenAIEmbeddingResponseItem{
+		Embedding: ollamaEmbeddingResponse.Embedding,
+		Object:    "embedding",
+	})
+	usage := &dto.Usage{
+		TotalTokens:      promptTokens,
+		CompletionTokens: 0,
+		PromptTokens:     promptTokens,
+	}
+	embeddingResponse := &dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   data,
+		Model:  model,
+		Usage:  *usage,
+	}
+	doResponseBody, err := json.Marshal(embeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report an error, and we cannot check the response at all.
+	// Copy headers
+	for k, v := range resp.Header {
+		// Delete any existing header with the same name first, to avoid adding duplicate headers
+		c.Writer.Header().Del(k)
+		for _, vv := range v {
+			c.Writer.Header().Add(k, vv)
+		}
+	}
+	// reset content length
+	c.Writer.Header().Del("Content-Length")
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	return nil, usage, nil
+}
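
Note (illustration only, not part of the patch): below is a minimal standalone Go sketch of the JSON translation this change performs, assuming a hypothetical embedding model name ("nomic-embed-text") and made-up prompt/usage values; the local structs only mirror the shapes of the dto.go types and the OpenAI embedding DTOs used above.

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors OllamaEmbeddingRequest: the body the adaptor sends to Ollama's /api/embeddings.
type ollamaEmbeddingRequest struct {
	Model  string `json:"model,omitempty"`
	Prompt any    `json:"prompt,omitempty"`
}

// Mirrors OllamaEmbeddingResponse: the body Ollama returns.
type ollamaEmbeddingResponse struct {
	Embedding []float64 `json:"embedding,omitempty"`
}

// Shape of the OpenAI-style response ollamaEmbeddingHandler writes back to the client.
type embeddingItem struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
}

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type embeddingResponse struct {
	Object string          `json:"object"`
	Data   []embeddingItem `json:"data"`
	Model  string          `json:"model"`
	Usage  usage           `json:"usage"`
}

func main() {
	// 1. The OpenAI-format embeddings input is forwarded upstream as {"model":...,"prompt":...}.
	req := ollamaEmbeddingRequest{Model: "nomic-embed-text", Prompt: "hello world"}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"model":"nomic-embed-text","prompt":"hello world"}

	// 2. Ollama replies with a bare embedding vector.
	var upstream ollamaEmbeddingResponse
	_ = json.Unmarshal([]byte(`{"embedding":[0.1,0.2,0.3]}`), &upstream)

	// 3. The handler rewraps it in the OpenAI list format; usage counts prompt tokens only.
	out := embeddingResponse{
		Object: "list",
		Data:   []embeddingItem{{Object: "embedding", Embedding: upstream.Embedding}},
		Model:  "nomic-embed-text",
		Usage:  usage{PromptTokens: 2, CompletionTokens: 0, TotalTokens: 2},
	}
	b, _ = json.Marshal(out)
	fmt.Println(string(b))
}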