diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index fac6b7f..ed4d5ca 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -17,11 +17,25 @@ type OllamaRequest struct {
 	PresencePenalty float64 `json:"presence_penalty,omitempty"`
 }
 
+type Options struct {
+	Seed             int     `json:"seed,omitempty"`
+	Temperature      float64 `json:"temperature,omitempty"`
+	TopK             int     `json:"top_k,omitempty"`
+	TopP             float64 `json:"top_p,omitempty"`
+	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
+	NumPredict       int     `json:"num_predict,omitempty"`
+	NumCtx           int     `json:"num_ctx,omitempty"`
+}
+
 type OllamaEmbeddingRequest struct {
-	Model  string `json:"model,omitempty"`
-	Prompt any    `json:"prompt,omitempty"`
+	Model   string   `json:"model,omitempty"`
+	Input   []string `json:"input"`
+	Options *Options `json:"options,omitempty"`
 }
 
 type OllamaEmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"`
+	Model     string    `json:"model"`
 	Embedding []float64 `json:"embedding,omitempty"`
 }
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 6bf395a..b2d4630 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -9,7 +9,7 @@ import (
+	"errors"
 	"net/http"
 	"one-api/dto"
 	"one-api/service"
-	"strings"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 
 func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
 	return &OllamaEmbeddingRequest{
-		Model:  request.Model,
-		Prompt: strings.Join(request.ParseInput(), " "),
+		Model: request.Model,
+		Input: request.ParseInput(),
+		Options: &Options{
+			Seed:             int(request.Seed),
+			Temperature:      request.Temperature,
+			TopP:             request.TopP,
+			FrequencyPenalty: request.FrequencyPenalty,
+			PresencePenalty:  request.PresencePenalty,
+		},
 	}
 }
 
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
+	if ollamaEmbeddingResponse.Error != "" {
+		return service.OpenAIErrorWrapper(errors.New(ollamaEmbeddingResponse.Error), "ollama_error", resp.StatusCode), nil
+	}
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
 	data = append(data, dto.OpenAIEmbeddingResponseItem{
 		Embedding: ollamaEmbeddingResponse.Embedding,