mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
Compare commits
1 Commits
12be617907
...
e392fdbf60
Author | SHA1 | Date | |
---|---|---|---|
|
e392fdbf60 |
@ -78,19 +78,13 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
model := meta.ActualModelName
|
||||
if strings.Contains(model, "?") {
|
||||
// TODO: Maybe fix meta.ActualModelName?
|
||||
model = strings.Split(model, "?")[0]
|
||||
}
|
||||
|
||||
if meta.BaseURL != "" {
|
||||
return fmt.Sprintf(
|
||||
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
meta.BaseURL,
|
||||
meta.Config.VertexAIProjectID,
|
||||
meta.Config.Region,
|
||||
model,
|
||||
meta.ActualModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
@ -99,7 +93,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
meta.Config.Region,
|
||||
meta.Config.VertexAIProjectID,
|
||||
meta.Config.Region,
|
||||
model,
|
||||
meta.ActualModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||
@ -24,26 +23,11 @@ var ModelList = []string{
|
||||
|
||||
type Adaptor struct {
|
||||
model string
|
||||
task EmbeddingTaskType
|
||||
}
|
||||
|
||||
var _ model2.InnerAIAdapter = (*Adaptor)(nil)
|
||||
|
||||
func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
|
||||
modelTaskType := EmbeddingTaskTypeNone
|
||||
if strings.Contains(model, "?") {
|
||||
parts := strings.Split(model, "?")
|
||||
modelName := parts[0]
|
||||
if len(parts) >= 2 {
|
||||
modelOptions, err := url.ParseQuery(parts[1])
|
||||
if err == nil {
|
||||
modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
|
||||
}
|
||||
}
|
||||
return modelName, modelTaskType
|
||||
}
|
||||
return model, modelTaskType
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
@ -52,13 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
|
||||
a.model = modelName
|
||||
parts := strings.Split(request.Model, "|")
|
||||
if len(parts) >= 2 {
|
||||
a.task = EmbeddingTaskType(parts[1])
|
||||
} else {
|
||||
a.task = EmbeddingTaskTypeSemanticSimilarity
|
||||
}
|
||||
a.model = parts[0]
|
||||
instances := make([]EmbeddingInstance, len(inputs))
|
||||
for i, input := range inputs {
|
||||
instances[i] = EmbeddingInstance{
|
||||
Content: input,
|
||||
TaskType: modelTaskType,
|
||||
TaskType: a.task,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,6 @@ package vertexai
|
||||
type EmbeddingTaskType string
|
||||
|
||||
const (
|
||||
EmbeddingTaskTypeNone EmbeddingTaskType = ""
|
||||
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
|
||||
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
|
||||
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
|
||||
|
Loading…
Reference in New Issue
Block a user