diff --git a/controller/channel.go b/controller/channel.go
index b98af41..d723ee6 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -1,6 +1,8 @@
package controller
import (
+ "encoding/json"
+ "fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
@@ -9,6 +11,34 @@ import (
"strings"
)
+type OpenAIModel struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ OwnedBy string `json:"owned_by"`
+ Permission []struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ AllowCreateEngine bool `json:"allow_create_engine"`
+ AllowSampling bool `json:"allow_sampling"`
+ AllowLogprobs bool `json:"allow_logprobs"`
+ AllowSearchIndices bool `json:"allow_search_indices"`
+ AllowView bool `json:"allow_view"`
+ AllowFineTuning bool `json:"allow_fine_tuning"`
+ Organization string `json:"organization"`
+ Group string `json:"group"`
+ IsBlocking bool `json:"is_blocking"`
+ } `json:"permission"`
+ Root string `json:"root"`
+ Parent string `json:"parent"`
+}
+
+type OpenAIModelsResponse struct {
+ Data []OpenAIModel `json:"data"`
+ Success bool `json:"success"`
+}
+
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
@@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
return
}
+func FetchUpstreamModels(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ channel, err := model.GetChannelById(id, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if channel.Type != common.ChannelTypeOpenAI {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "仅支持 OpenAI 类型渠道",
+ })
+ return
+ }
+ url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ }
+ result := OpenAIModelsResponse{}
+ err = json.Unmarshal(body, &result)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ }
+ if !result.Success {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "上游返回错误",
+ })
+ }
+
+ var ids []string
+ for _, model := range result.Data {
+ ids = append(ids, model.ID)
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": ids,
+ })
+}
+
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {
diff --git a/router/api-router.go b/router/api-router.go
index b98a94e..1dee81e 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
+ channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 2412002..2c39093 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -15,6 +15,7 @@ import {
Space,
Spin,
Button,
+ Tooltip,
Input,
Typography,
Select,
@@ -36,6 +37,8 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500',
};
+const fetchButtonTips = "1. 新建渠道时,请求通过当前浏览器发出;2. 编辑已有渠道,请求通过后端服务器发出"
+
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) {
@@ -88,30 +91,51 @@ const EditChannel = (props) => {
const [fullModels, setFullModels] = useState([]);
const [customModel, setCustomModel] = useState('');
- const fetchUpstreamModelList = (name) => {
- const url = inputs['base_url'] + '/v1/models';
- const key = inputs['key']
- axios.get(url, {
- headers: {
- 'Authorization': `Bearer ${key}`
- }
- }).then((res) => {
+ const fetchUpstreamModelList = async (name) => {
+ if (inputs["type"] !== 1) {
+ showError("仅支持 OpenAI 接口格式")
+ return;
+ }
+ const models = inputs["models"] || []
+ let err = false;
+ if (isEdit) {
+ const res = await API.get("/api/channel/fetch_models/" + channelId)
if (res.data && res.data?.success) {
- const models = res.data.data.map((model) => model.id);
- handleInputChange(name, models);
- showSuccess("获取模型列表成功");
+ models.push(...res.data.data)
} else {
- showError('获取模型列表失败');
+ err = true
}
- }).catch((error) => {
- console.log(error);
- const errCode = error.response.status;
- if (errCode === 401) {
- showError(`获取模型列表失败,错误代码 ${errCode},请检查密钥是否填写`);
- } else {
- showError(`获取模型列表失败,错误代码 ${errCode}`);
+ } else {
+ if (!inputs?.["key"]) {
+ showError("请填写密钥")
+ return;
}
- })
+ try {
+ const host = new URL((inputs["base_url"] || "https://api.openai.com"))
+
+ const url = `https://${host.hostname}/v1/models`;
+ const key = inputs["key"];
+ const res = await axios.get(url, {
+ headers: {
+ 'Authorization': `Bearer ${key}`
+ }
+ })
+ if (res.data && res.data?.success) {
+ models.push(...es.data.data.map((model) => model.id))
+ } else {
+ err = true
+ }
+ }
+ catch (error) {
+ err = true
+ }
+ }
+ if (!err) {
+ handleInputChange(name, Array.from(new Set(models)));
+ showSuccess("获取模型列表成功");
+ } else {
+ showError('获取模型列表失败');
+ }
}
@@ -575,14 +599,16 @@ const EditChannel = (props) => {
>
填入所有模型
-
+
+
+