Compare commits


2 Commits

Author       SHA1          Message                              Date
RandyZhang   12be617907    Merge 4907952fd1 into 8df4a2670b     2025-05-12 10:24:54 +00:00
RandyZhang   4907952fd1    support vertex embedding task_type   2025-05-12 18:24:42 +08:00
3 changed files with 29 additions and 11 deletions

View File

@@ -78,13 +78,19 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		}
 	}
 
+	model := meta.ActualModelName
+	if strings.Contains(model, "?") {
+		// TODO: Maybe fix meta.ActualModelName?
+		model = strings.Split(model, "?")[0]
+	}
+
 	if meta.BaseURL != "" {
 		return fmt.Sprintf(
 			"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
 			meta.BaseURL,
 			meta.Config.VertexAIProjectID,
 			meta.Config.Region,
-			meta.ActualModelName,
+			model,
 			suffix,
 		), nil
 	}
@@ -93,7 +99,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		meta.Config.Region,
 		meta.Config.VertexAIProjectID,
 		meta.Config.Region,
-		meta.ActualModelName,
+		model,
 		suffix,
 	), nil
 }
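For context, a small self-contained sketch of why the "?" suffix has to be stripped before the URL is built. The helper name stripModelOptions, the project/region values, and the ":predict" suffix are illustrative assumptions, not code from this PR:

package main

import (
	"fmt"
	"strings"
)

// stripModelOptions mirrors the new logic in GetRequestURL: anything after "?"
// in the configured model name is treated as adaptor-level options, not as part
// of the Vertex AI model identifier. (Illustrative helper, not from this PR.)
func stripModelOptions(model string) string {
	if strings.Contains(model, "?") {
		return strings.Split(model, "?")[0]
	}
	return model
}

func main() {
	// Example project, region, model, and suffix values.
	model := stripModelOptions("text-embedding-004?task_type=RETRIEVAL_QUERY")
	url := fmt.Sprintf(
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
		"us-central1", "my-project", "us-central1", model, "predict",
	)
	fmt.Println(url)
	// Without the strip, the ":predict" suffix would be appended after
	// "?task_type=RETRIEVAL_QUERY", producing a malformed publisher-model URL.
}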

View File

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"io"
 	"net/http"
+	"net/url"
 	"strings"
 
 	"github.com/songquanpeng/one-api/relay/adaptor/gemini"
@@ -23,11 +24,26 @@ var ModelList = []string{
 
 type Adaptor struct {
 	model string
-	task  EmbeddingTaskType
 }
 
 var _ model2.InnerAIAdapter = (*Adaptor)(nil)
 
+func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
+	modelTaskType := EmbeddingTaskTypeNone
+	if strings.Contains(model, "?") {
+		parts := strings.Split(model, "?")
+		modelName := parts[0]
+		if len(parts) >= 2 {
+			modelOptions, err := url.ParseQuery(parts[1])
+			if err == nil {
+				modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
+			}
+		}
+		return modelName, modelTaskType
+	}
+	return model, modelTaskType
+}
+
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
@@ -36,18 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 	if len(inputs) == 0 {
 		return nil, errors.New("request is nil")
 	}
-	parts := strings.Split(request.Model, "|")
-	if len(parts) >= 2 {
-		a.task = EmbeddingTaskType(parts[1])
-	} else {
-		a.task = EmbeddingTaskTypeSemanticSimilarity
-	}
-	a.model = parts[0]
+	modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
+	a.model = modelName
 
 	instances := make([]EmbeddingInstance, len(inputs))
 	for i, input := range inputs {
 		instances[i] = EmbeddingInstance{
 			Content:  input,
-			TaskType: a.task,
+			TaskType: modelTaskType,
 		}
 	}
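To make the new behaviour concrete, here is a standalone sketch of the same parsing idea. The model name and printed values are examples, and the sketch uses strings.Index where the PR uses strings.Split:

package main

import (
	"fmt"
	"net/url"
	"strings"
)

func main() {
	// Example model name as a channel or user might configure it; the part
	// after "?" is parsed as a query string and "task_type" is extracted.
	model := "text-multilingual-embedding-002?task_type=RETRIEVAL_DOCUMENT"

	name, taskType := model, ""
	if i := strings.Index(model, "?"); i >= 0 {
		name = model[:i]
		if q, err := url.ParseQuery(model[i+1:]); err == nil {
			taskType = q.Get("task_type")
		}
	}

	fmt.Println(name)     // text-multilingual-embedding-002
	fmt.Println(taskType) // RETRIEVAL_DOCUMENT
}

When the model name carries no "?" query, the task type stays empty, which the PR represents as the new EmbeddingTaskTypeNone constant, replacing the old "model|TASK_TYPE" syntax and its fallback to EmbeddingTaskTypeSemanticSimilarity.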

View File

@@ -3,6 +3,7 @@ package vertexai
 type EmbeddingTaskType string
 
 const (
+	EmbeddingTaskTypeNone               EmbeddingTaskType = ""
 	EmbeddingTaskTypeRetrievalQuery     EmbeddingTaskType = "RETRIEVAL_QUERY"
 	EmbeddingTaskTypeRetrievalDocument  EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
 	EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
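For completeness, a rough client-side usage sketch. It assumes a local one-api instance on port 3000, a placeholder token, and that "text-embedding-004" is an available Vertex AI embedding model on the channel; only the "?task_type=..." suffix in the model field is the part this PR adds:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Assumed deployment address, token, and model name; adjust for your setup.
	body := []byte(`{"model":"text-embedding-004?task_type=SEMANTIC_SIMILARITY","input":"hello world"}`)

	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-your-one-api-token")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status)
	fmt.Println(string(out))
}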