diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index e132d2f..35f236a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -6,12 +6,15 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "os" "one-api/dto" "one-api/relay/channel" + "strings" relaycommon "one-api/relay/common" ) type Adaptor struct { + modelVersionMap map[string]string } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -25,18 +28,32 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} - -// 定义一个映射,存储模型名称和对应的版本 -var modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-ultra": "v1beta", + modelVersionMapStr := os.Getenv("GEMINI_MODEL_API") + if modelVersionMapStr == "" { + a.modelVersionMap = map[string]string{ + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", + } + return + } + a.modelVersionMap = make(map[string]string) + for _, pair := range strings.Split(modelVersionMapStr, ",") { + parts := strings.Split(pair, ":") + if len(parts) == 2 { + a.modelVersionMap[parts[0]] = parts[1] + } + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := modelVersionMap[info.UpstreamModelName] + version, beta := a.modelVersionMap[info.UpstreamModelName] if !beta { if info.ApiVersion != "" { version = info.ApiVersion