mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 01:06:37 +08:00
Compare commits
2 Commits
e392fdbf60
...
12be617907
Author | SHA1 | Date | |
---|---|---|---|
|
12be617907 | ||
|
4907952fd1 |
@ -78,13 +78,19 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
model := meta.ActualModelName
|
||||
if strings.Contains(model, "?") {
|
||||
// TODO: Maybe fix meta.ActualModelName?
|
||||
model = strings.Split(model, "?")[0]
|
||||
}
|
||||
|
||||
if meta.BaseURL != "" {
|
||||
return fmt.Sprintf(
|
||||
"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
meta.BaseURL,
|
||||
meta.Config.VertexAIProjectID,
|
||||
meta.Config.Region,
|
||||
meta.ActualModelName,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
@ -93,7 +99,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
meta.Config.Region,
|
||||
meta.Config.VertexAIProjectID,
|
||||
meta.Config.Region,
|
||||
meta.ActualModelName,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||
@ -23,11 +24,26 @@ var ModelList = []string{
|
||||
|
||||
type Adaptor struct {
|
||||
model string
|
||||
task EmbeddingTaskType
|
||||
}
|
||||
|
||||
var _ model2.InnerAIAdapter = (*Adaptor)(nil)
|
||||
|
||||
func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
|
||||
modelTaskType := EmbeddingTaskTypeNone
|
||||
if strings.Contains(model, "?") {
|
||||
parts := strings.Split(model, "?")
|
||||
modelName := parts[0]
|
||||
if len(parts) >= 2 {
|
||||
modelOptions, err := url.ParseQuery(parts[1])
|
||||
if err == nil {
|
||||
modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
|
||||
}
|
||||
}
|
||||
return modelName, modelTaskType
|
||||
}
|
||||
return model, modelTaskType
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
@ -36,18 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
parts := strings.Split(request.Model, "|")
|
||||
if len(parts) >= 2 {
|
||||
a.task = EmbeddingTaskType(parts[1])
|
||||
} else {
|
||||
a.task = EmbeddingTaskTypeSemanticSimilarity
|
||||
}
|
||||
a.model = parts[0]
|
||||
modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
|
||||
a.model = modelName
|
||||
instances := make([]EmbeddingInstance, len(inputs))
|
||||
for i, input := range inputs {
|
||||
instances[i] = EmbeddingInstance{
|
||||
Content: input,
|
||||
TaskType: a.task,
|
||||
TaskType: modelTaskType,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ package vertexai
|
||||
type EmbeddingTaskType string
|
||||
|
||||
const (
|
||||
EmbeddingTaskTypeNone EmbeddingTaskType = ""
|
||||
EmbeddingTaskTypeRetrievalQuery EmbeddingTaskType = "RETRIEVAL_QUERY"
|
||||
EmbeddingTaskTypeRetrievalDocument EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
|
||||
EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
|
||||
|
Loading…
Reference in New Issue
Block a user