mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 01:06:37 +08:00
support vertex embedding task_type
This commit is contained in:
parent
209a14c26f
commit
4907952fd1
@ -78,13 +78,19 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model := meta.ActualModelName
|
||||||
|
if strings.Contains(model, "?") {
|
||||||
|
// TODO: Maybe fix meta.ActualModelName?
|
||||||
|
model = strings.Split(model, "?")[0]
|
||||||
|
}
|
||||||
|
|
||||||
if meta.BaseURL != "" {
|
if meta.BaseURL != "" {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||||
meta.BaseURL,
|
meta.BaseURL,
|
||||||
meta.Config.VertexAIProjectID,
|
meta.Config.VertexAIProjectID,
|
||||||
meta.Config.Region,
|
meta.Config.Region,
|
||||||
meta.ActualModelName,
|
model,
|
||||||
suffix,
|
suffix,
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
@ -93,7 +99,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
|||||||
meta.Config.Region,
|
meta.Config.Region,
|
||||||
meta.Config.VertexAIProjectID,
|
meta.Config.VertexAIProjectID,
|
||||||
meta.Config.Region,
|
meta.Config.Region,
|
||||||
meta.ActualModelName,
|
model,
|
||||||
suffix,
|
suffix,
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||||
@ -23,11 +24,26 @@ var ModelList = []string{
|
|||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
model string
|
model string
|
||||||
task EmbeddingTaskType
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model2.InnerAIAdapter = (*Adaptor)(nil)
|
var _ model2.InnerAIAdapter = (*Adaptor)(nil)
|
||||||
|
|
||||||
|
func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
|
||||||
|
modelTaskType := EmbeddingTaskTypeNone
|
||||||
|
if strings.Contains(model, "?") {
|
||||||
|
parts := strings.Split(model, "?")
|
||||||
|
modelName := parts[0]
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
modelOptions, err := url.ParseQuery(parts[1])
|
||||||
|
if err == nil {
|
||||||
|
modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelName, modelTaskType
|
||||||
|
}
|
||||||
|
return model, modelTaskType
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
@ -36,18 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
if len(inputs) == 0 {
|
if len(inputs) == 0 {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
parts := strings.Split(request.Model, "|")
|
modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
|
||||||
if len(parts) >= 2 {
|
a.model = modelName
|
||||||
a.task = EmbeddingTaskType(parts[1])
|
|
||||||
} else {
|
|
||||||
a.task = EmbeddingTaskTypeSemanticSimilarity
|
|
||||||
}
|
|
||||||
a.model = parts[0]
|
|
||||||
instances := make([]EmbeddingInstance, len(inputs))
|
instances := make([]EmbeddingInstance, len(inputs))
|
||||||
for i, input := range inputs {
|
for i, input := range inputs {
|
||||||
instances[i] = EmbeddingInstance{
|
instances[i] = EmbeddingInstance{
|
||||||
Content: input,
|
Content: input,
|
||||||
TaskType: a.task,
|
TaskType: modelTaskType,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package vertexai
|
|||||||
type EmbeddingTaskType string
|
type EmbeddingTaskType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
EmbeddingTaskTypeNone EmbeddingTaskType = ""
|
||||||
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
|
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
|
||||||
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
|
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
|
||||||
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
|
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
|
||||||
|
Loading…
Reference in New Issue
Block a user