Compare commits

..

1 Commits

Author SHA1 Message Date
RandyZhang
e392fdbf60
Merge 209a14c26f into 8df4a2670b 2025-05-10 15:17:24 +08:00
3 changed files with 11 additions and 29 deletions

View File

@ -78,19 +78,13 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
}
}
model := meta.ActualModelName
if strings.Contains(model, "?") {
// TODO: Maybe fix meta.ActualModelName?
model = strings.Split(model, "?")[0]
}
if meta.BaseURL != "" {
return fmt.Sprintf(
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
meta.BaseURL,
meta.Config.VertexAIProjectID,
meta.Config.Region,
model,
meta.ActualModelName,
suffix,
), nil
}
@ -99,7 +93,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
meta.Config.Region,
meta.Config.VertexAIProjectID,
meta.Config.Region,
model,
meta.ActualModelName,
suffix,
), nil
}

View File

@ -4,7 +4,6 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
@ -24,26 +23,11 @@ var ModelList = []string{
type Adaptor struct {
model string
task EmbeddingTaskType
}
var _ model2.InnerAIAdapter = (*Adaptor)(nil)
func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
modelTaskType := EmbeddingTaskTypeNone
if strings.Contains(model, "?") {
parts := strings.Split(model, "?")
modelName := parts[0]
if len(parts) >= 2 {
modelOptions, err := url.ParseQuery(parts[1])
if err == nil {
modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
}
}
return modelName, modelTaskType
}
return model, modelTaskType
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
@ -52,13 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if len(inputs) == 0 {
return nil, errors.New("request is nil")
}
modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
a.model = modelName
parts := strings.Split(request.Model, "|")
if len(parts) >= 2 {
a.task = EmbeddingTaskType(parts[1])
} else {
a.task = EmbeddingTaskTypeSemanticSimilarity
}
a.model = parts[0]
instances := make([]EmbeddingInstance, len(inputs))
for i, input := range inputs {
instances[i] = EmbeddingInstance{
Content: input,
TaskType: modelTaskType,
TaskType: a.task,
}
}

View File

@ -3,7 +3,6 @@ package vertexai
type EmbeddingTaskType string
const (
EmbeddingTaskTypeNone EmbeddingTaskType = ""
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"