feat: support ollama multi-text embedding

This commit is contained in:
CalciumIon 2024-09-12 18:29:45 +08:00
parent 1675679be9
commit 0cbf8e07e7
2 changed files with 28 additions and 5 deletions

View File

@ -17,11 +17,25 @@ type OllamaRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
} }
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct { type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"` Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
} }
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"` Embedding []float64 `json:"embedding,omitempty"`
} }

View File

@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/service" "one-api/service"
"strings"
) )
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "), Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{ data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding, Embedding: ollamaEmbeddingResponse.Embedding,