diff --git a/controller/channel.go b/controller/channel.go index b98af41..d723ee6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,6 +1,8 @@ package controller import ( + "encoding/json" + "fmt" "github.com/gin-gonic/gin" "net/http" "one-api/common" @@ -9,6 +11,34 @@ import ( "strings" ) +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group string `json:"group"` + IsBlocking bool `json:"is_blocking"` + } `json:"permission"` + Root string `json:"root"` + Parent string `json:"parent"` +} + +type OpenAIModelsResponse struct { + Data []OpenAIModel `json:"data"` + Success bool `json:"success"` +} + func GetAllChannels(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) @@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) { return } +func FetchUpstreamModels(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if channel.Type != common.ChannelTypeOpenAI { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "仅支持 OpenAI 类型渠道", + }) + return + } + url := fmt.Sprintf("%s/v1/models", *channel.BaseURL) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + } + result := OpenAIModelsResponse{} + err = json.Unmarshal(body, &result) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + } + if !result.Success { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "上游返回错误", + }) + } + + var ids []string + for _, model := range result.Data { + ids = append(ids, model.ID) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ids, + }) +} + func FixChannelsAbilities(c *gin.Context) { count, err := model.FixAbility() if err != nil { diff --git a/router/api-router.go b/router/api-router.go index b98a94e..1dee81e 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) { channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.POST("/batch", controller.DeleteChannelBatch) channelRoute.POST("/fix", controller.FixChannelsAbilities) + channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels) + } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 2412002..2c39093 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -15,6 +15,7 @@ import { Space, Spin, Button, + Tooltip, Input, Typography, Select, @@ -36,6 +37,8 @@ const STATUS_CODE_MAPPING_EXAMPLE = { 400: '500', }; +const fetchButtonTips = "1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出" + function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { @@ -88,30 +91,51 @@ const EditChannel = (props) => { const [fullModels, setFullModels] = useState([]); const [customModel, setCustomModel] = useState(''); - const fetchUpstreamModelList = (name) => { - const url = inputs['base_url'] + '/v1/models'; - const key = inputs['key'] - axios.get(url, { - headers: { - 'Authorization': `Bearer ${key}` - } - }).then((res) => { + const fetchUpstreamModelList = async (name) => { + if (inputs["type"] !== 1) { + showError("仅支持 OpenAI 接口格式") + return; + } + const models = inputs["models"] || [] + let err = false; + if (isEdit) { + const res = await API.get("/api/channel/fetch_models/" + channelId) if (res.data && res.data?.success) { - const models = res.data.data.map((model) => model.id); - handleInputChange(name, models); - showSuccess("获取模型列表成功"); + models.push(...res.data.data) } else { - showError('获取模型列表失败'); + err = true } - }).catch((error) => { - console.log(error); - const errCode = error.response.status; - if (errCode === 401) { - showError(`获取模型列表失败,错误代码 ${errCode},请检查密钥是否填写`); - } else { - showError(`获取模型列表失败,错误代码 ${errCode}`); + } else { + if (!inputs?.["key"]) { + showError("请填写密钥") + return; } - }) + try { + const host = new URL((inputs["base_url"] || "https://api.openai.com")) + + const url = `https://${host.hostname}/v1/models`; + const key = inputs["key"]; + const res = await axios.get(url, { + headers: { + 'Authorization': `Bearer ${key}` + } + }) + if (res.data && res.data?.success) { + models.push(...es.data.data.map((model) => model.id)) + } else { + err = true + } + } + catch (error) { + err = true + } + } + if (!err) { + handleInputChange(name, Array.from(new Set(models))); + showSuccess("获取模型列表成功"); + } else { + showError('获取模型列表失败'); + } } @@ -575,14 +599,16 @@ const EditChannel = (props) => { > 填入所有模型 - + + +