diff --git a/controller/channel.go b/controller/channel.go index b98af41..d723ee6 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,6 +1,8 @@ package controller import ( + "encoding/json" + "fmt" "github.com/gin-gonic/gin" "net/http" "one-api/common" @@ -9,6 +11,34 @@ import ( "strings" ) +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group string `json:"group"` + IsBlocking bool `json:"is_blocking"` + } `json:"permission"` + Root string `json:"root"` + Parent string `json:"parent"` +} + +type OpenAIModelsResponse struct { + Data []OpenAIModel `json:"data"` + Success bool `json:"success"` +} + func GetAllChannels(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) @@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) { return } +func FetchUpstreamModels(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if channel.Type != common.ChannelTypeOpenAI { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "仅支持 OpenAI 类型渠道", + }) + return + } + url := fmt.Sprintf("%s/v1/models", *channel.BaseURL) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + } + result := OpenAIModelsResponse{} + err = json.Unmarshal(body, &result) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + } + if !result.Success { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "上游返回错误", + }) + } + + var ids []string + for _, model := range result.Data { + ids = append(ids, model.ID) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ids, + }) +} + func FixChannelsAbilities(c *gin.Context) { count, err := model.FixAbility() if err != nil { diff --git a/router/api-router.go b/router/api-router.go index b98a94e..1dee81e 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) { channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.POST("/batch", controller.DeleteChannelBatch) channelRoute.POST("/fix", controller.FixChannelsAbilities) + channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels) + } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index a32a065..14f15db 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -15,6 +15,7 @@ import { Space, Spin, Button, + Tooltip, Input, Typography, Select, @@ -24,6 +25,7 @@ import { } from '@douyinfe/semi-ui'; import { Divider } from 'semantic-ui-react'; import { getChannelModels, loadChannelModels } from '../../components/utils.js'; +import axios from 'axios'; const MODEL_MAPPING_EXAMPLE = { 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', @@ -331,6 +333,7 @@ const EditChannel = (props) => { handleInputChange('models', localModels); }; + return ( <> { > 填入所有模型 + + +