feat: support ollama embedding

CaIon 2024-03-26 19:53:53 +08:00
parent 319e97d677
commit 8eedad9470
3 changed files with 120 additions and 23 deletions
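In short: when an embeddings request is relayed to an Ollama channel, the adaptor now targets Ollama's native /api/embeddings endpoint, converts the OpenAI-style request body to Ollama's format, and rewraps the response as an OpenAI-style embedding list. A minimal client-side sketch of the feature, assuming the gateway listens on localhost:3000 and exposes the usual OpenAI-compatible /v1/embeddings route; the token and model name are placeholders:

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
)

func main() {
    // OpenAI-style embeddings request; the Ollama adaptor translates it upstream.
    body := []byte(`{"model": "nomic-embed-text", "input": "hello world"}`)
    req, err := http.NewRequest("POST", "http://localhost:3000/v1/embeddings", bytes.NewReader(body))
    if err != nil {
        panic(err)
    }
    req.Header.Set("Authorization", "Bearer sk-...") // placeholder one-api token
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    out, _ := io.ReadAll(resp.Body)
    fmt.Println(string(out)) // {"object":"list","data":[{"object":"embedding","embedding":[...]}],...}
}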

File 1 of 3: the Ollama adaptor (URL selection, request conversion, response dispatch)

@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 )
@@ -19,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
 }

 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return info.BaseUrl + "/api/embeddings", nil
+	default:
+		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	}
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
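The effect of the new switch, sketched below with a hypothetical BaseUrl (Ollama's default local address): embeddings requests go to Ollama's native endpoint, while every other mode keeps the existing GetFullRequestURL resolution. One design observation: the embeddings branch is plain string concatenation, so it assumes the channel's BaseUrl is stored without a trailing slash.

package main

import "fmt"

// Simplified stand-in for the adaptor's URL selection; baseURL and the
// request path are assumptions, and the default branch is reduced to a
// plain join in place of relaycommon.GetFullRequestURL.
func requestURL(baseURL string, isEmbeddings bool, requestPath string) string {
    if isEmbeddings {
        return baseURL + "/api/embeddings"
    }
    return baseURL + requestPath
}

func main() {
    fmt.Println(requestURL("http://localhost:11434", true, ""))
    // http://localhost:11434/api/embeddings
    fmt.Println(requestURL("http://localhost:11434", false, "/v1/chat/completions"))
    // http://localhost:11434/v1/chat/completions
}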
@@ -31,7 +37,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(*request), nil
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return requestOpenAI2Embeddings(*request), nil
+	default:
+		return requestOpenAI2Ollama(*request), nil
+	}
 }

 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
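For embeddings the conversion is deliberately thin: requestOpenAI2Embeddings (third file below) keeps only the model name and passes the OpenAI input field through as Ollama's prompt. A sketch of the mapping, using local stand-ins for the real dto/ollama types and invented payload values:

package main

import (
    "encoding/json"
    "fmt"
)

// Minimal stand-ins for the real types, for illustration only.
type openAIRequest struct {
    Model string `json:"model"`
    Input any    `json:"input"`
}

type ollamaEmbeddingRequest struct {
    Model  string `json:"model,omitempty"`
    Prompt any    `json:"prompt,omitempty"`
}

func main() {
    in := openAIRequest{Model: "nomic-embed-text", Input: "hello world"}
    out := ollamaEmbeddingRequest{Model: in.Model, Prompt: in.Input} // same shape as requestOpenAI2Embeddings
    b, _ := json.Marshal(out)
    fmt.Println(string(b)) // {"model":"nomic-embed-text","prompt":"hello world"}
}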
@@ -44,7 +55,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
-		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		if info.RelayMode == relayconstant.RelayModeEmbeddings {
+			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		} else {
+			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		}
 	}
 	return
 }

File 2 of 3: Ollama DTO definitions (request/response types)

@@ -3,16 +3,24 @@ package ollama
 import "one-api/dto"

 type OllamaRequest struct {
-	Model    string         `json:"model,omitempty"`
-	Messages []dto.Message  `json:"messages,omitempty"`
-	Stream   bool           `json:"stream,omitempty"`
-	Options  *OllamaOptions `json:"options,omitempty"`
+	Model       string        `json:"model,omitempty"`
+	Messages    []dto.Message `json:"messages,omitempty"`
+	Stream      bool          `json:"stream,omitempty"`
+	Temperature float64       `json:"temperature,omitempty"`
+	Seed        float64       `json:"seed,omitempty"`
+	Topp        float64       `json:"top_p,omitempty"`
+	TopK        int           `json:"top_k,omitempty"`
+	Stop        any           `json:"stop,omitempty"`
 }

-type OllamaOptions struct {
-	Temperature float64 `json:"temperature,omitempty"`
-	Seed        float64 `json:"seed,omitempty"`
-	Topp        float64 `json:"top_p,omitempty"`
-	TopK        int     `json:"top_k,omitempty"`
-	Stop        any     `json:"stop,omitempty"`
+type OllamaEmbeddingRequest struct {
+	Model  string `json:"model,omitempty"`
+	Prompt any    `json:"prompt,omitempty"`
 }
+
+type OllamaEmbeddingResponse struct {
+	Embedding []float64 `json:"embedding,omitempty"`
+}
+
+//type OllamaOptions struct {
+//}
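Two things happen here: OllamaRequest loses its nested OllamaOptions (the generation parameters move to the top level, and OllamaOptions survives only as a commented-out stub), and two new types describe the embedding exchange. Going by the struct tags, the embedding wire format is one prompt in, one vector out; a sketch of parsing the upstream reply, with a local copy of the type and invented values:

package main

import (
    "encoding/json"
    "fmt"
)

// Local copy of the new response type, for illustration only.
type ollamaEmbeddingResponse struct {
    Embedding []float64 `json:"embedding,omitempty"`
}

func main() {
    // What Ollama's /api/embeddings is expected to return (values invented).
    raw := []byte(`{"embedding": [0.1, -0.2, 0.3]}`)

    var resp ollamaEmbeddingResponse
    if err := json.Unmarshal(raw, &resp); err != nil {
        panic(err)
    }
    fmt.Println(resp.Embedding) // [0.1 -0.2 0.3]
}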

File 3 of 3: Ollama relay logic (request conversion and the new embedding handler)

@@ -1,7 +1,14 @@
 package ollama

 import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
 	"one-api/dto"
+	"one-api/service"
 )

 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {

@@ -20,15 +27,82 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 		Stop, _ = request.Stop.([]string)
 	}
 	return &OllamaRequest{
-		Model:    request.Model,
-		Messages: messages,
-		Stream:   request.Stream,
-		Options: &OllamaOptions{
-			Temperature: request.Temperature,
-			Seed:        request.Seed,
-			Topp:        request.TopP,
-			TopK:        request.TopK,
-			Stop:        Stop,
-		},
+		Model:       request.Model,
+		Messages:    messages,
+		Stream:      request.Stream,
+		Temperature: request.Temperature,
+		Seed:        request.Seed,
+		Topp:        request.TopP,
+		TopK:        request.TopK,
+		Stop:        Stop,
 	}
 }
+
+func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+	return &OllamaEmbeddingRequest{
+		Model:  request.Model,
+		Prompt: request.Input,
+	}
+}
+
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var ollamaEmbeddingResponse OllamaEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
+	data = append(data, dto.OpenAIEmbeddingResponseItem{
+		Embedding: ollamaEmbeddingResponse.Embedding,
+		Object:    "embedding",
+	})
+	usage := &dto.Usage{
+		TotalTokens:      promptTokens,
+		CompletionTokens: 0,
+		PromptTokens:     promptTokens,
+	}
+	embeddingResponse := &dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   data,
+		Model:  model,
+		Usage:  *usage,
+	}
+	doResponseBody, err := json.Marshal(embeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
+	// Don't write response headers before the upstream body has been parsed:
+	// if parsing fails we still need to send an error response, but by then
+	// the headers would already be written, and an HTTP client (Postman, for
+	// example) would see a confusing, half-formed reply.
+	// Copy headers
+	for k, v := range resp.Header {
+		// Remove any existing header with the same name to avoid duplicates.
+		c.Writer.Header().Del(k)
+		for _, vv := range v {
+			c.Writer.Header().Add(k, vv)
+		}
+	}
+	// Reset Content-Length to match the rewritten body.
+	c.Writer.Header().Del("Content-Length")
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	return nil, usage, nil
+}
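End to end, ollamaEmbeddingHandler rewraps Ollama's single vector as a one-item OpenAI-style list and bills prompt tokens only (CompletionTokens is 0 and TotalTokens equals PromptTokens, matching the usual embeddings accounting). Since OllamaEmbeddingResponse carries exactly one embedding, the handler always emits exactly one data item; OpenAI-style batch inputs are not fanned out by this commit. A sketch of the translation, with invented values and local stand-ins for the dto types:

package main

import (
    "encoding/json"
    "fmt"
)

// Local stand-ins for the dto types, for illustration only.
type embeddingItem struct {
    Object    string    `json:"object"`
    Embedding []float64 `json:"embedding"`
}

type usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

type embeddingResponse struct {
    Object string          `json:"object"`
    Data   []embeddingItem `json:"data"`
    Model  string          `json:"model"`
    Usage  usage           `json:"usage"`
}

func main() {
    // Upstream Ollama reply: {"embedding":[0.1,-0.2,0.3]}
    vector := []float64{0.1, -0.2, 0.3}
    promptTokens := 2 // invented count

    out := embeddingResponse{
        Object: "list",
        Data:   []embeddingItem{{Object: "embedding", Embedding: vector}},
        Model:  "nomic-embed-text",
        Usage:  usage{PromptTokens: promptTokens, CompletionTokens: 0, TotalTokens: promptTokens},
    }
    b, _ := json.Marshal(out)
    fmt.Println(string(b))
}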