From 4907952fd13dd7640f3e89e260b1446b733f31da Mon Sep 17 00:00:00 2001 From: RandyZhang Date: Mon, 12 May 2025 18:23:40 +0800 Subject: [PATCH] support vertex embedding task_type --- relay/adaptor/vertexai/adaptor.go | 10 +++++-- relay/adaptor/vertexai/embedding/adapter.go | 29 ++++++++++++++------- relay/adaptor/vertexai/embedding/model.go | 1 + 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go index afae18a1..ae0d8c9d 100644 --- a/relay/adaptor/vertexai/adaptor.go +++ b/relay/adaptor/vertexai/adaptor.go @@ -78,13 +78,19 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { } } + model := meta.ActualModelName + if strings.Contains(model, "?") { + // TODO: Maybe fix meta.ActualModelName? + model = strings.Split(model, "?")[0] + } + if meta.BaseURL != "" { return fmt.Sprintf( "%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", meta.BaseURL, meta.Config.VertexAIProjectID, meta.Config.Region, - meta.ActualModelName, + model, suffix, ), nil } @@ -93,7 +99,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region, - meta.ActualModelName, + model, suffix, ), nil } diff --git a/relay/adaptor/vertexai/embedding/adapter.go b/relay/adaptor/vertexai/embedding/adapter.go index 17004928..47cd88fa 100644 --- a/relay/adaptor/vertexai/embedding/adapter.go +++ b/relay/adaptor/vertexai/embedding/adapter.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io" "net/http" + "net/url" "strings" "github.com/songquanpeng/one-api/relay/adaptor/gemini" @@ -23,11 +24,26 @@ var ModelList = []string{ type Adaptor struct { model string - task EmbeddingTaskType } var _ model2.InnerAIAdapter = (*Adaptor)(nil) +func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) { + modelTaskType := EmbeddingTaskTypeNone + if strings.Contains(model, "?") { + parts := strings.Split(model, "?") + modelName := parts[0] + if len(parts) >= 2 { + modelOptions, err := url.ParseQuery(parts[1]) + if err == nil { + modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type")) + } + } + return modelName, modelTaskType + } + return model, modelTaskType +} + func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") @@ -36,18 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if len(inputs) == 0 { return nil, errors.New("request is nil") } - parts := strings.Split(request.Model, "|") - if len(parts) >= 2 { - a.task = EmbeddingTaskType(parts[1]) - } else { - a.task = EmbeddingTaskTypeSemanticSimilarity - } - a.model = parts[0] + modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model) + a.model = modelName instances := make([]EmbeddingInstance, len(inputs)) for i, input := range inputs { instances[i] = EmbeddingInstance{ Content: input, - TaskType: a.task, + TaskType: modelTaskType, } } diff --git a/relay/adaptor/vertexai/embedding/model.go b/relay/adaptor/vertexai/embedding/model.go index 7e5f86fa..e9079569 100644 --- a/relay/adaptor/vertexai/embedding/model.go +++ b/relay/adaptor/vertexai/embedding/model.go @@ -3,6 +3,7 @@ package vertexai type EmbeddingTaskType string const ( + EmbeddingTaskTypeNone EmbeddingTaskType = "" EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY" EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT" EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"