diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 29637296..579ba30a 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -435,3 +435,17 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
 	_, err = c.Writer.Write(jsonResponse)
 	return nil, &fullTextResponse.Usage
 }
+
+// EmbeddingResponseHandler writes an already-converted OpenAI-format embedding
+// response to the client and reports its usage.
+func EmbeddingResponseHandler(c *gin.Context, statusCode int, resp *openai.EmbeddingResponse) (*model.ErrorWithStatusCode, *model.Usage) {
+	jsonResponse, err := json.Marshal(resp)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(statusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &resp.Usage
+}
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index c3acae60..96dee113 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -65,13 +65,19 @@ type ChatTools struct {
 	FunctionDeclarations any `json:"function_declarations,omitempty"`
 }
 
-type ChatGenerationConfig struct {
-	ResponseMimeType string   `json:"responseMimeType,omitempty"`
-	ResponseSchema   any      `json:"responseSchema,omitempty"`
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopP             *float64 `json:"topP,omitempty"`
-	TopK             float64  `json:"topK,omitempty"`
-	MaxOutputTokens  int      `json:"maxOutputTokens,omitempty"`
-	CandidateCount   int      `json:"candidateCount,omitempty"`
-	StopSequences    []string `json:"stopSequences,omitempty"`
+// ThinkingConfig deliberately omits omitempty on its fields: a zero
+// ThinkingBudget must still be serialized to disable thinking.
+type ThinkingConfig struct {
+	IncludeThoughts bool `json:"includeThoughts"`
+	ThinkingBudget  int  `json:"thinkingBudget"`
+}
+
+type ChatGenerationConfig struct {
+	ResponseMimeType string          `json:"responseMimeType,omitempty"`
+	ResponseSchema   any             `json:"responseSchema,omitempty"`
+	Temperature      *float64        `json:"temperature,omitempty"`
+	TopP             *float64        `json:"topP,omitempty"`
+	TopK             float64         `json:"topK,omitempty"`
+	MaxOutputTokens  int             `json:"maxOutputTokens,omitempty"`
+	CandidateCount   int             `json:"candidateCount,omitempty"`
+	StopSequences    []string        `json:"stopSequences,omitempty"`
+	ThinkingConfig   *ThinkingConfig `json:"thinkingConfig,omitempty"`
 }
diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go
index 3fab4a45..ae0d8c9d 100644
--- a/relay/adaptor/vertexai/adaptor.go
+++ b/relay/adaptor/vertexai/adaptor.go
@@ -61,12 +61,15 @@ func (a *Adaptor) GetChannelName() string {
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	suffix := ""
-	if strings.HasPrefix(meta.ActualModelName, "gemini") {
+	modelType := PredictModelType(meta.ActualModelName)
+	if modelType == VertexAIGemini {
 		if meta.IsStream {
 			suffix = "streamGenerateContent?alt=sse"
 		} else {
 			suffix = "generateContent"
 		}
+	} else if modelType == VertexAIEmbedding {
+		suffix = "predict"
 	} else {
 		if meta.IsStream {
 			suffix = "streamRawPredict?alt=sse"
@@ -75,13 +78,19 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		}
 	}
 
+	model := meta.ActualModelName
+	if strings.Contains(model, "?") {
+		// TODO: consider stripping these query-style options from meta.ActualModelName upstream instead.
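+		// e.g. "gemini-2.5-flash?thinking=1" or
+		// "text-multilingual-embedding-002?task_type=RETRIEVAL_QUERY" must be
+		// reduced to the bare model name before it is placed in the request URL.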
+		model = strings.Split(model, "?")[0]
+	}
+
 	if meta.BaseURL != "" {
 		return fmt.Sprintf(
 			"%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
 			meta.BaseURL,
 			meta.Config.VertexAIProjectID,
 			meta.Config.Region,
-			meta.ActualModelName,
+			model,
 			suffix,
 		), nil
 	}
@@ -90,7 +99,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		meta.Config.Region,
 		meta.Config.VertexAIProjectID,
 		meta.Config.Region,
-		meta.ActualModelName,
+		model,
 		suffix,
 	), nil
 }
@@ -115,3 +124,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
 func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
 	return channelhelper.DoRequestHelper(a, c, meta, requestBody)
 }
+
+// PredictModelType guesses the backend type from the model name prefix.
+func PredictModelType(model string) VertexAIModelType {
+	if strings.HasPrefix(model, "gemini-") {
+		return VertexAIGemini
+	}
+	// "textembedding-gecko-*" models (see embedding.ModelList) are also embeddings.
+	if strings.HasPrefix(model, "text-embedding") || strings.HasPrefix(model, "text-multilingual-embedding") || strings.HasPrefix(model, "textembedding-gecko") {
+		return VertexAIEmbedding
+	}
+	return VertexAIClaude
+}
diff --git a/relay/adaptor/vertexai/embedding/adapter.go b/relay/adaptor/vertexai/embedding/adapter.go
new file mode 100644
index 00000000..47cd88fa
--- /dev/null
+++ b/relay/adaptor/vertexai/embedding/adapter.go
@@ -0,0 +1,118 @@
+package vertexai
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+
+	"github.com/songquanpeng/one-api/relay/adaptor/gemini"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
+	vmodel "github.com/songquanpeng/one-api/relay/adaptor/vertexai/model"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+var ModelList = []string{
+	"textembedding-gecko-multilingual@001", "text-multilingual-embedding-002",
+}
+
+type Adaptor struct {
+	model string
+}
+
+var _ vmodel.InnerAIAdapter = (*Adaptor)(nil)
+
+// parseEmbeddingTaskType splits an option-suffixed model name such as
+// "text-multilingual-embedding-002?task_type=RETRIEVAL_QUERY" into the bare
+// model name and the requested task type.
+func (a *Adaptor) parseEmbeddingTaskType(model string) (string, EmbeddingTaskType) {
+	modelTaskType := EmbeddingTaskTypeNone
+	if strings.Contains(model, "?") {
+		parts := strings.Split(model, "?")
+		modelName := parts[0]
+		if len(parts) >= 2 {
+			modelOptions, err := url.ParseQuery(parts[1])
+			if err == nil {
+				modelTaskType = EmbeddingTaskType(modelOptions.Get("task_type"))
+			}
+		}
+		return modelName, modelTaskType
+	}
+	return model, modelTaskType
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	inputs := request.ParseInput()
+	if len(inputs) == 0 {
+		return nil, errors.New("input is empty")
+	}
+	modelName, modelTaskType := a.parseEmbeddingTaskType(request.Model)
+	a.model = modelName
+	instances := make([]EmbeddingInstance, len(inputs))
+	for i, input := range inputs {
+		instances[i] = EmbeddingInstance{
+			Content:  input,
+			TaskType: modelTaskType,
+		}
+	}
+
+	embeddingRequest := EmbeddingRequest{
+		Instances: instances,
+		Parameters: EmbeddingParams{
+			OutputDimensionality: request.Dimensions,
+		},
+	}
+
+	return embeddingRequest, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+	err, usage = EmbeddingHandler(c, a.model, resp)
+	return
+}
+
+func EmbeddingHandler(c *gin.Context, modelName string, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var vertexEmbeddingResponse EmbeddingResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if resp.StatusCode != http.StatusOK {
+		// Surface the upstream failure instead of wrapping a nil read error.
+		return openai.ErrorWrapper(errors.Errorf("upstream returned status %d: %s", resp.StatusCode, responseBody), "vertexai_embedding_upstream_error", resp.StatusCode), nil
+	}
+	err = json.Unmarshal(responseBody, &vertexEmbeddingResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	openaiResp := &openai.EmbeddingResponse{
+		Model: modelName,
+		Data:  make([]openai.EmbeddingResponseItem, 0, len(vertexEmbeddingResponse.Predictions)),
+		Usage: model.Usage{
+			TotalTokens: 0,
+		},
+	}
+
+	// Collect the vectors and sum the per-prediction token counts in one pass.
+	for i, pred := range vertexEmbeddingResponse.Predictions {
+		openaiResp.Data = append(openaiResp.Data, openai.EmbeddingResponseItem{
+			Index:     i,
+			Embedding: pred.Embeddings.Values,
+		})
+		openaiResp.Usage.TotalTokens += pred.Embeddings.Statistics.TokenCount
+	}
+
+	return gemini.EmbeddingResponseHandler(c, resp.StatusCode, openaiResp)
+}
diff --git a/relay/adaptor/vertexai/embedding/model.go b/relay/adaptor/vertexai/embedding/model.go
new file mode 100644
index 00000000..e9079569
--- /dev/null
+++ b/relay/adaptor/vertexai/embedding/model.go
@@ -0,0 +1,46 @@
+package vertexai
+
+type EmbeddingTaskType string
+
+const (
+	EmbeddingTaskTypeNone               EmbeddingTaskType = ""
+	EmbeddingTaskTypeRetrievalQuery     EmbeddingTaskType = "RETRIEVAL_QUERY"
+	EmbeddingTaskTypeRetrievalDocument  EmbeddingTaskType = "RETRIEVAL_DOCUMENT"
+	EmbeddingTaskTypeSemanticSimilarity EmbeddingTaskType = "SEMANTIC_SIMILARITY"
+	EmbeddingTaskTypeClassification     EmbeddingTaskType = "CLASSIFICATION"
+	EmbeddingTaskTypeClustering         EmbeddingTaskType = "CLUSTERING"
+	EmbeddingTaskTypeQuestionAnswering  EmbeddingTaskType = "QUESTION_ANSWERING"
+	EmbeddingTaskTypeFactVerification   EmbeddingTaskType = "FACT_VERIFICATION"
+	EmbeddingTaskTypeCodeRetrievalQuery EmbeddingTaskType = "CODE_RETRIEVAL_QUERY"
+)
+
+type EmbeddingRequest struct {
+	Instances  []EmbeddingInstance `json:"instances"`
+	Parameters EmbeddingParams     `json:"parameters"`
+}
+
+type EmbeddingInstance struct {
+	Content  string            `json:"content"`
+	TaskType EmbeddingTaskType `json:"task_type,omitempty"`
+	Title    string            `json:"title,omitempty"`
+}
+
+type EmbeddingParams struct {
+	AutoTruncate         bool `json:"autoTruncate,omitempty"`
+	OutputDimensionality int  `json:"outputDimensionality,omitempty"`
+	// Texts []string `json:"texts,omitempty"`
+}
+
+type EmbeddingResponse struct {
+	Predictions []struct {
+		Embeddings EmbeddingData `json:"embeddings"`
+	} `json:"predictions"`
+}
+
+type EmbeddingData struct {
+	Statistics struct {
+		Truncated  bool `json:"truncated"`
+		TokenCount int  `json:"token_count"`
+	} `json:"statistics"`
+	Values []float64 `json:"values"`
+}
diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go
index f5b245d8..bf28a84c 100644
--- a/relay/adaptor/vertexai/gemini/adapter.go
+++ b/relay/adaptor/vertexai/gemini/adapter.go
@@ -2,6 +2,9 @@ package vertexai
 
 import (
 	"net/http"
+	"net/url"
+	"strconv"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
@@ -27,13 +30,57 @@ var ModelList = []string{
 
 type Adaptor struct {
 }
 
+// parseGeminiChatGenerationThinking extracts thinking options appended to the
+// model name, e.g. "gemini-2.5-flash?thinking=1&thinking_budget=1024".
+func (a *Adaptor) parseGeminiChatGenerationThinking(model string) (string, *gemini.ThinkingConfig) {
+	thinkingConfig := &gemini.ThinkingConfig{
+		IncludeThoughts: false,
+		ThinkingBudget:  0,
+	}
+	modelName := model
+	if strings.Contains(model, "?") {
+		parts := strings.Split(model, "?")
+		// Strip the options even if they fail to parse, so the bare model name is used downstream.
+		modelName = parts[0]
+		if len(parts) >= 2 {
+			modelOptions, err := url.ParseQuery(parts[1])
+			if err == nil {
+				if modelOptions.Has("thinking") {
+					thinkingConfig.IncludeThoughts = modelOptions.Get("thinking") == "1"
+				}
+				thinkingBudget := modelOptions.Get("thinking_budget")
+				if thinkingBudget != "" {
+					thinkingBudgetInt, err := strconv.Atoi(thinkingBudget)
+					if err == nil {
+						thinkingConfig.ThinkingBudget = thinkingBudgetInt
+					}
+				}
+			}
+		}
+	}
+	if strings.HasPrefix(modelName, "gemini-2.5") {
+		// Gemini 2.5 models accept a thinking config and enable thinking by default;
+		// an explicit thinkingConfig must be sent to opt out of thinking mode.
+		return modelName, thinkingConfig
+	}
+	// Other models do not support thinking yet; only pass a config through when the
+	// options explicitly request it, so future models can opt in via these parameters.
+	if thinkingConfig.IncludeThoughts || thinkingConfig.ThinkingBudget > 0 {
+		return modelName, thinkingConfig
+	}
+	return modelName, nil
+}
+
 func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-
+	modelName, thinkingConfig := a.parseGeminiChatGenerationThinking(request.Model)
+	request.Model = modelName
 	geminiRequest := gemini.ConvertRequest(*request)
-	c.Set(ctxkey.RequestModel, request.Model)
+	if thinkingConfig != nil {
+		geminiRequest.GenerationConfig.ThinkingConfig = thinkingConfig
+	}
+	c.Set(ctxkey.RequestModel, modelName)
 	c.Set(ctxkey.ConvertedRequest, geminiRequest)
 	return geminiRequest, nil
 }
diff --git a/relay/adaptor/vertexai/model/model.go b/relay/adaptor/vertexai/model/model.go
new file mode 100644
index 00000000..a39d38c2
--- /dev/null
+++ b/relay/adaptor/vertexai/model/model.go
@@ -0,0 +1,13 @@
+package model
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/songquanpeng/one-api/relay/meta"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
+// InnerAIAdapter is the minimal surface a Vertex AI sub-adaptor
+// (Claude, Gemini, embedding) must implement to plug into the registry.
+type InnerAIAdapter interface {
+	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
+}
diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go
index 41099f02..bc2fef50 100644
--- a/relay/adaptor/vertexai/registry.go
+++ b/relay/adaptor/vertexai/registry.go
@@ -1,20 +1,18 @@
 package vertexai
 
 import (
-	"net/http"
-
-	"github.com/gin-gonic/gin"
 	claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
+	embedding "github.com/songquanpeng/one-api/relay/adaptor/vertexai/embedding"
 	gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
-	"github.com/songquanpeng/one-api/relay/meta"
-	"github.com/songquanpeng/one-api/relay/model"
+	"github.com/songquanpeng/one-api/relay/adaptor/vertexai/model"
 )
 
 type VertexAIModelType int
 
 const (
-	VerterAIClaude VertexAIModelType = iota + 1
-	VerterAIGemini
+	VertexAIClaude VertexAIModelType = iota + 1
+	VertexAIGemini
+	VertexAIEmbedding
 )
 
 var modelMapping = map[string]VertexAIModelType{}
@@ -23,28 +21,37 @@ var modelList = []string{}
 
 func init() {
 	modelList = append(modelList, claude.ModelList...)
 	for _, model := range claude.ModelList {
-		modelMapping[model] = VerterAIClaude
+		modelMapping[model] = VertexAIClaude
 	}
 
 	modelList = append(modelList, gemini.ModelList...)
 	for _, model := range gemini.ModelList {
-		modelMapping[model] = VerterAIGemini
+		modelMapping[model] = VertexAIGemini
+	}
+
+	modelList = append(modelList, embedding.ModelList...)
+	for _, model := range embedding.ModelList {
+		modelMapping[model] = VertexAIEmbedding
 	}
 }
 
-type innerAIAdapter interface {
-	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
-	DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
-}
-
-func GetAdaptor(model string) innerAIAdapter {
+func GetAdaptor(model string) model.InnerAIAdapter {
 	adaptorType := modelMapping[model]
 	switch adaptorType {
-	case VerterAIClaude:
+	case VertexAIClaude:
 		return &claude.Adaptor{}
-	case VerterAIGemini:
+	case VertexAIGemini:
 		return &gemini.Adaptor{}
+	case VertexAIEmbedding:
+		return &embedding.Adaptor{}
 	default:
+		// Model names missing from modelMapping (e.g. option-suffixed names such
+		// as "gemini-2.5-pro?thinking=1") fall back to prefix-based prediction;
+		// models predicted as Claude are deliberately not guessed and return nil.
+		adaptorType = PredictModelType(model)
+		switch adaptorType {
+		case VertexAIGemini:
+			return &gemini.Adaptor{}
+		case VertexAIEmbedding:
+			return &embedding.Adaptor{}
+		}
 		return nil
 	}
 }
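// ---------------------------------------------------------------------------
// Usage sketch (not part of the diff): how the "?key=value" model-name options
// introduced above are expected to behave. The model names and values are
// illustrative only; splitModelOptions mirrors the splitting done by
// parseEmbeddingTaskType and parseGeminiChatGenerationThinking rather than
// calling those unexported methods.
// ---------------------------------------------------------------------------
package main

import (
	"fmt"
	"net/url"
	"strconv"
	"strings"
)

// splitModelOptions separates "name?opts" into the bare model name and the
// parsed options; it returns nil options when there are none or they are malformed.
func splitModelOptions(model string) (string, url.Values) {
	name, query, found := strings.Cut(model, "?")
	if !found {
		return model, nil
	}
	opts, err := url.ParseQuery(query)
	if err != nil {
		return name, nil
	}
	return name, opts
}

func main() {
	// Embedding: task_type is copied onto every EmbeddingInstance in the request.
	name, opts := splitModelOptions("text-multilingual-embedding-002?task_type=RETRIEVAL_QUERY")
	fmt.Println(name, opts.Get("task_type")) // text-multilingual-embedding-002 RETRIEVAL_QUERY

	// Gemini 2.5: thinking=1 sets IncludeThoughts, thinking_budget caps the budget.
	name, opts = splitModelOptions("gemini-2.5-flash?thinking=1&thinking_budget=1024")
	budget, _ := strconv.Atoi(opts.Get("thinking_budget"))
	fmt.Println(name, opts.Get("thinking") == "1", budget) // gemini-2.5-flash true 1024
}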