diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index bdf8815b..26788ae7 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -8,19 +8,18 @@ import ( "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/geminiv2" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - - "github.com/gin-gonic/gin" ) // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn @@ -61,12 +60,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { }, }, GenerationConfig: ChatGenerationConfig{ - Temperature: textRequest.Temperature, - TopP: textRequest.TopP, - MaxOutputTokens: textRequest.MaxTokens, - ResponseModalities: []string{ - "TEXT", "IMAGE", - }, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + ResponseModalities: geminiv2.GetModelModalities(textRequest.Model), }, } if textRequest.ResponseFormat != nil { diff --git a/relay/adaptor/geminiv2/constants.go b/relay/adaptor/geminiv2/constants.go index 1921bcac..ba42bdb9 100644 --- a/relay/adaptor/geminiv2/constants.go +++ b/relay/adaptor/geminiv2/constants.go @@ -1,5 +1,7 @@ package geminiv2 +import "strings" + // https://ai.google.dev/models/gemini var ModelList = []string{ @@ -14,3 +16,17 @@ var ModelList = []string{ "gemini-2.0-pro-exp-02-05", "gemini-2.0-flash-exp-image-generation", } + +const ( + ModalityText = "TEXT" + ModalityImage = "IMAGE" +) + +// GetModelModalities returns the modalities of the model. +func GetModelModalities(model string) []string { + if strings.Contains(model, "-image-generation") { + return []string{ModalityText, ModalityImage} + } + + return []string{ModalityText} +}