From 209a14c26ff78ac8a6bbe92014475b421710732a Mon Sep 17 00:00:00 2001 From: RandyZhang Date: Sat, 10 May 2025 15:09:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0vertex=20embedding=E7=9A=84?= =?UTF-8?q?=E6=94=AF=E6=8C=81=EF=BC=8C=E4=BF=AE=E6=94=B9vertex=E7=9A=84?= =?UTF-8?q?=E6=A8=A1=E5=9E=8Badapter=E5=8C=B9=E9=85=8D=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/adaptor/gemini/main.go | 14 +++ relay/adaptor/vertexai/adaptor.go | 15 ++- relay/adaptor/vertexai/embedding/adapter.go | 107 ++++++++++++++++++++ relay/adaptor/vertexai/embedding/model.go | 45 ++++++++ relay/adaptor/vertexai/model/model.go | 13 +++ relay/adaptor/vertexai/registry.go | 41 ++++---- 6 files changed, 217 insertions(+), 18 deletions(-) create mode 100644 relay/adaptor/vertexai/embedding/adapter.go create mode 100644 relay/adaptor/vertexai/embedding/model.go create mode 100644 relay/adaptor/vertexai/model/model.go diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 29637296..579ba30a 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -435,3 +435,17 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } + +func EmbeddingResponseHandler(c *gin.Context, statusCode int, resp *openai.EmbeddingResponse) (*model.ErrorWithStatusCode, *model.Usage) { + jsonResponse, err := json.Marshal(resp) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(statusCode) + _, err = c.Writer.Write(jsonResponse) + if err != nil { + return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &resp.Usage +} diff --git a/relay/adaptor/vertexai/adaptor.go 
b/relay/adaptor/vertexai/adaptor.go index 3fab4a45..afae18a1 100644 --- a/relay/adaptor/vertexai/adaptor.go +++ b/relay/adaptor/vertexai/adaptor.go @@ -61,12 +61,15 @@ func (a *Adaptor) GetChannelName() string { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { suffix := "" - if strings.HasPrefix(meta.ActualModelName, "gemini") { + modelType := PredictModelType(meta.ActualModelName) + if modelType == VertexAIGemini { if meta.IsStream { suffix = "streamGenerateContent?alt=sse" } else { suffix = "generateContent" } + } else if modelType == VertexAIEmbedding { + suffix = "predict" } else { if meta.IsStream { suffix = "streamRawPredict?alt=sse" @@ -115,3 +118,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } + +func PredictModelType(model string) VertexAIModelType { + if strings.HasPrefix(model, "gemini-") { + return VertexAIGemini + } + // NOTE(review): must also match "textembedding" — ModelList below includes "textembedding-gecko-multilingual@001". + if strings.HasPrefix(model, "text-embedding") || strings.HasPrefix(model, "text-multilingual-embedding") || strings.HasPrefix(model, "textembedding") { + return VertexAIEmbedding + } + return VertexAIClaude +} diff --git a/relay/adaptor/vertexai/embedding/adapter.go b/relay/adaptor/vertexai/embedding/adapter.go new file mode 100644 index 00000000..17004928 --- /dev/null +++ b/relay/adaptor/vertexai/embedding/adapter.go @@ -0,0 +1,108 @@ +package vertexai + +import ( + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + model2 "github.com/songquanpeng/one-api/relay/adaptor/vertexai/model" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var ModelList = []string{ + "textembedding-gecko-multilingual@001", "text-multilingual-embedding-002", 
+} + +type Adaptor struct { + model string + task EmbeddingTaskType +} + +var _ model2.InnerAIAdapter = (*Adaptor)(nil) + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + inputs := request.ParseInput() + if len(inputs) == 0 { + return nil, errors.New("input is empty") + } + parts := strings.Split(request.Model, "|") + if len(parts) >= 2 { + a.task = EmbeddingTaskType(parts[1]) + } else { + a.task = EmbeddingTaskTypeSemanticSimilarity + } + a.model = parts[0] + instances := make([]EmbeddingInstance, len(inputs)) + for i, input := range inputs { + instances[i] = EmbeddingInstance{ + Content: input, + TaskType: a.task, + } + } + + embeddingRequest := EmbeddingRequest{ + Instances: instances, + Parameters: EmbeddingParams{ + OutputDimensionality: request.Dimensions, + }, + } + + return embeddingRequest, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + err, usage = EmbeddingHandler(c, a.model, resp) + return +} + +func EmbeddingHandler(c *gin.Context, modelName string, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var vertexEmbeddingResponse EmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + if resp.StatusCode != http.StatusOK { + return openai.ErrorWrapper(errors.Errorf("upstream status code: %d, body: %s", resp.StatusCode, responseBody), "bad_response_status_code", resp.StatusCode), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &vertexEmbeddingResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + openaiResp
:= &openai.EmbeddingResponse{ + Model: modelName, + Data: make([]openai.EmbeddingResponseItem, 0, len(vertexEmbeddingResponse.Predictions)), + Usage: model.Usage{ + TotalTokens: 0, + }, + } + + for i, pred := range vertexEmbeddingResponse.Predictions { + openaiResp.Data = append(openaiResp.Data, openai.EmbeddingResponseItem{ + Index: i, + Embedding: pred.Embeddings.Values, + }) + } + + for _, pred := range vertexEmbeddingResponse.Predictions { + openaiResp.Usage.TotalTokens += pred.Embeddings.Statistics.TokenCount + } + + return gemini.EmbeddingResponseHandler(c, resp.StatusCode, openaiResp) +} diff --git a/relay/adaptor/vertexai/embedding/model.go b/relay/adaptor/vertexai/embedding/model.go new file mode 100644 index 00000000..7e5f86fa --- /dev/null +++ b/relay/adaptor/vertexai/embedding/model.go @@ -0,0 +1,45 @@ +package vertexai + +type EmbeddingTaskType string + +const ( + EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY" + EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT" + EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY" + EmbeddingTaskTypeClassification EmbeddingTaskType = "CLASSIFICATION" + EmbeddingTaskTypeClustering EmbeddingTaskType = "CLUSTERING" + EmbeddingTaskTypeQuestionAnswering EmbeddingTaskType = "QUESTION_ANSWERING" + EmbeddingTaskTypeFactVerification EmbeddingTaskType = "FACT_VERIFICATION" + EmbeddingTaskTypeCodeRetrievalQuery EmbeddingTaskType = "CODE_RETRIEVAL_QUERY" +) + +type EmbeddingRequest struct { + Instances []EmbeddingInstance `json:"instances"` + Parameters EmbeddingParams `json:"parameters"` +} + +type EmbeddingInstance struct { + Content string `json:"content"` + TaskType EmbeddingTaskType `json:"task_type,omitempty"` + Title string `json:"title,omitempty"` +} + +type EmbeddingParams struct { + AutoTruncate bool `json:"autoTruncate,omitempty"` + OutputDimensionality int `json:"outputDimensionality,omitempty"` + // Texts []string `json:"texts,omitempty"` +} + 
+type EmbeddingResponse struct { + Predictions []struct { + Embeddings EmbeddingData `json:"embeddings"` + } `json:"predictions"` +} + +type EmbeddingData struct { + Statistics struct { + Truncated bool `json:"truncated"` + TokenCount int `json:"token_count"` + } `json:"statistics"` + Values []float64 `json:"values"` +} diff --git a/relay/adaptor/vertexai/model/model.go b/relay/adaptor/vertexai/model/model.go new file mode 100644 index 00000000..a39d38c2 --- /dev/null +++ b/relay/adaptor/vertexai/model/model.go @@ -0,0 +1,13 @@ +package model + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "net/http" +) + +type InnerAIAdapter interface { + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) +} diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go index 41099f02..bc2fef50 100644 --- a/relay/adaptor/vertexai/registry.go +++ b/relay/adaptor/vertexai/registry.go @@ -1,20 +1,18 @@ package vertexai import ( - "net/http" - - "github.com/gin-gonic/gin" claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude" + embedding "github.com/songquanpeng/one-api/relay/adaptor/vertexai/embedding" gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini" - "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai/model" ) type VertexAIModelType int const ( - VerterAIClaude VertexAIModelType = iota + 1 - VerterAIGemini + VertexAIClaude VertexAIModelType = iota + 1 + VertexAIGemini + VertexAIEmbedding ) var modelMapping = map[string]VertexAIModelType{} @@ -23,28 +21,37 @@ var modelList = []string{} func init() { modelList = append(modelList, claude.ModelList...) 
for _, model := range claude.ModelList { - modelMapping[model] = VerterAIClaude + modelMapping[model] = VertexAIClaude } modelList = append(modelList, gemini.ModelList...) for _, model := range gemini.ModelList { - modelMapping[model] = VerterAIGemini + modelMapping[model] = VertexAIGemini + } + + modelList = append(modelList, embedding.ModelList...) + for _, model := range embedding.ModelList { + modelMapping[model] = VertexAIEmbedding } } -type innerAIAdapter interface { - ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) -} - -func GetAdaptor(model string) innerAIAdapter { +func GetAdaptor(model string) model.InnerAIAdapter { adaptorType := modelMapping[model] switch adaptorType { - case VerterAIClaude: + case VertexAIClaude: return &claude.Adaptor{} - case VerterAIGemini: + case VertexAIGemini: return &gemini.Adaptor{} + case VertexAIEmbedding: + return &embedding.Adaptor{} default: + adaptorType = PredictModelType(model) + switch adaptorType { + case VertexAIGemini: + return &gemini.Adaptor{} + case VertexAIEmbedding: + return &embedding.Adaptor{} + } return nil } }