From 22a98c58796ccfc7b4fac449d25cb032e8494a97 Mon Sep 17 00:00:00 2001 From: HowieWu <98788152+utopeadia@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:20:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9Gemini=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用GEMINI_MODEL_API环境变量覆盖默认版本映射,使用","分隔不同模型和版本 -e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta" --- relay/channel/gemini/adaptor.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index e132d2f..35f236a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -6,12 +6,15 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "os" "one-api/dto" "one-api/relay/channel" + "strings" relaycommon "one-api/relay/common" ) type Adaptor struct { + modelVersionMap map[string]string } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -25,18 +28,32 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} - -// 定义一个映射,存储模型名称和对应的版本 -var modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-ultra": "v1beta", + modelVersionMapStr := os.Getenv("GEMINI_MODEL_API") + if modelVersionMapStr == "" { + a.modelVersionMap = map[string]string{ + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", + } + return + } + a.modelVersionMap = make(map[string]string) + for _, pair := range strings.Split(modelVersionMapStr, ",") { + parts := strings.Split(pair, ":") + if len(parts) == 2 { + a.modelVersionMap[parts[0]] = parts[1] + } + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := modelVersionMap[info.UpstreamModelName] + version, beta := a.modelVersionMap[info.UpstreamModelName] if !beta { if info.ApiVersion != "" { version = info.ApiVersion