mirror of https://github.com/linux-do/new-api.git, synced 2025-09-17 16:06:38 +08:00
feat: support ollama embedding
This commit is contained in:
parent 319e97d677
commit 8eedad9470
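In short: the Ollama adaptor now recognizes the embeddings relay mode. GetRequestURL routes embedding calls to Ollama's native /api/embeddings endpoint, ConvertRequest translates the OpenAI-style request into a new OllamaEmbeddingRequest DTO, and a new ollamaEmbeddingHandler converts Ollama's reply back into the OpenAI embeddings response format. Along the way, the Options sub-struct of OllamaRequest is flattened into the request struct itself.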
@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
 	"one-api/service"
 )
 
@@ -19,7 +20,12 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	switch info.RelayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return info.BaseUrl + "/api/embeddings", nil
+	default:
 		return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
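For orientation, here is a minimal, self-contained sketch of the routing this switch produces. The constants, helper name, and localhost base URL below are stand-ins invented for the example, not the repo's actual definitions (the real GetFullRequestURL also applies per-channel-type logic that is elided here):

```go
package main

import "fmt"

// Stand-ins for one-api's relay-mode constants; the real values live
// in one-api/relay/constant and may differ.
const (
	relayModeChatCompletions = iota
	relayModeEmbeddings
)

// requestURL mirrors the switch added to GetRequestURL: embeddings
// are pinned to Ollama's native endpoint, while everything else keeps
// a pass-through URL built from the incoming request path.
func requestURL(baseURL, requestPath string, relayMode int) string {
	switch relayMode {
	case relayModeEmbeddings:
		return baseURL + "/api/embeddings"
	default:
		return baseURL + requestPath
	}
}

func main() {
	base := "http://localhost:11434" // Ollama's default address
	fmt.Println(requestURL(base, "/v1/chat/completions", relayModeChatCompletions))
	fmt.Println(requestURL(base, "/v1/embeddings", relayModeEmbeddings))
}
```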
@@ -31,7 +37,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	switch relayMode {
+	case relayconstant.RelayModeEmbeddings:
+		return requestOpenAI2Embeddings(*request), nil
+	default:
 		return requestOpenAI2Ollama(*request), nil
+	}
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
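Worth noting: requestOpenAI2Embeddings (shown further down) forwards the OpenAI input field verbatim as Ollama's prompt, and the DTO types it as any, so a plain string maps cleanly while array inputs are passed through untranslated and depend on what the upstream endpoint accepts.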
@@ -43,9 +54,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		var responseText string
 		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
+		if info.RelayMode == relayconstant.RelayModeEmbeddings {
+			err, usage, sensitiveResp = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		} else {
 			err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+		}
 	}
 	return
 }
 
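Note the dispatch shape: streaming responses still go through openai.OpenaiStreamHandler unconditionally; only in the non-stream branch does embeddings mode divert to the new ollamaEmbeddingHandler.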
@@ -6,13 +6,21 @@ type OllamaRequest struct {
 	Model    string        `json:"model,omitempty"`
 	Messages []dto.Message `json:"messages,omitempty"`
 	Stream   bool          `json:"stream,omitempty"`
-	Options  *OllamaOptions `json:"options,omitempty"`
-}
-
-type OllamaOptions struct {
 	Temperature float64 `json:"temperature,omitempty"`
 	Seed        float64 `json:"seed,omitempty"`
 	Topp        float64 `json:"top_p,omitempty"`
 	TopK        int     `json:"top_k,omitempty"`
 	Stop        any     `json:"stop,omitempty"`
 }
+
+type OllamaEmbeddingRequest struct {
+	Model  string `json:"model,omitempty"`
+	Prompt any    `json:"prompt,omitempty"`
+}
+
+type OllamaEmbeddingResponse struct {
+	Embedding []float64 `json:"embedding,omitempty"`
+}
+
+//type OllamaOptions struct {
+//}
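For reference, a quick sketch of what the two new DTOs put on the wire, using standalone copies of the structs; the model name and input text are illustrative only:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Standalone copies of the two new DTOs, for illustration only.
type OllamaEmbeddingRequest struct {
	Model  string `json:"model,omitempty"`
	Prompt any    `json:"prompt,omitempty"`
}

type OllamaEmbeddingResponse struct {
	Embedding []float64 `json:"embedding,omitempty"`
}

func main() {
	// Request body the relay sends to Ollama's /api/embeddings.
	req, _ := json.Marshal(OllamaEmbeddingRequest{
		Model:  "nomic-embed-text", // illustrative model name
		Prompt: "hello world",
	})
	fmt.Println(string(req)) // {"model":"nomic-embed-text","prompt":"hello world"}

	// Shape of the reply the handler parses back.
	var resp OllamaEmbeddingResponse
	_ = json.Unmarshal([]byte(`{"embedding":[0.1,0.2,0.3]}`), &resp)
	fmt.Println(resp.Embedding) // [0.1 0.2 0.3]
}
```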
@@ -1,7 +1,14 @@
 package ollama
 
 import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
 	"one-api/dto"
+	"one-api/service"
 )
 
 func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -23,12 +30,79 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 		Model:    request.Model,
 		Messages: messages,
 		Stream:   request.Stream,
-		Options: &OllamaOptions{
 		Temperature: request.Temperature,
 		Seed:        request.Seed,
 		Topp:        request.TopP,
 		TopK:        request.TopK,
 		Stop:        Stop,
-		},
 	}
 }
+
+func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+	return &OllamaEmbeddingRequest{
+		Model:  request.Model,
+		Prompt: request.Input,
+	}
+}
+
+func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
+	var ollamaEmbeddingResponse OllamaEmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
+	data = append(data, dto.OpenAIEmbeddingResponseItem{
+		Embedding: ollamaEmbeddingResponse.Embedding,
+		Object:    "embedding",
+	})
+	usage := &dto.Usage{
+		TotalTokens:      promptTokens,
+		CompletionTokens: 0,
+		PromptTokens:     promptTokens,
+	}
+	embeddingResponse := &dto.OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   data,
+		Model:  model,
+		Usage:  *usage,
+	}
+	doResponseBody, err := json.Marshal(embeddingResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	// Copy headers
+	for k, v := range resp.Header {
+		// delete any existing values for this header to avoid adding duplicates
+		c.Writer.Header().Del(k)
+		for _, vv := range v {
+			c.Writer.Header().Add(k, vv)
+		}
+	}
+	// reset content length
+	c.Writer.Header().Del("Content-Length")
+	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
+	}
+	return nil, usage, nil
+}
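Two properties of this handler follow directly from the code: Ollama's /api/embeddings reply carries a single embedding vector, so the OpenAI-format data array always contains exactly one item; and because the reply carries no token counts, usage is synthesized from the relay's prompt-token estimate, with completion tokens fixed at zero. A minimal sketch of that translation, detached from the gin and one-api types (struct names and values here are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative stand-ins for dto.OpenAIEmbeddingResponse and friends.
type embeddingItem struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
}

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type embeddingList struct {
	Object string          `json:"object"`
	Data   []embeddingItem `json:"data"`
	Model  string          `json:"model"`
	Usage  usage           `json:"usage"`
}

func main() {
	// Suppose Ollama answered {"embedding":[0.1,0.2]} and the relay
	// had already estimated the prompt at 3 tokens.
	out := embeddingList{
		Object: "list",
		Data:   []embeddingItem{{Object: "embedding", Embedding: []float64{0.1, 0.2}}},
		Model:  "nomic-embed-text", // illustrative
		Usage:  usage{PromptTokens: 3, CompletionTokens: 0, TotalTokens: 3},
	}
	b, _ := json.Marshal(out)
	fmt.Println(string(b))
}
```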