添加同步上游模型列表按钮:添加提示以及支持已有渠道获取

This commit is contained in:
bubu 2024-05-21 22:16:20 +08:00
parent 6fe643b1c1
commit e2663a5c66
3 changed files with 145 additions and 28 deletions

View File

@ -1,6 +1,8 @@
package controller package controller
import ( import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -9,6 +11,34 @@ import (
"strings" "strings"
) )
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
func GetAllChannels(c *gin.Context) { func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
@ -35,6 +65,65 @@ func GetAllChannels(c *gin.Context) {
return return
} }
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
})
}
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ids,
})
}
func FixChannelsAbilities(c *gin.Context) { func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility() count, err := model.FixAbility()
if err != nil { if err != nil {

View File

@ -90,6 +90,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch) channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities) channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
} }
tokenRoute := apiRouter.Group("/token") tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth()) tokenRoute.Use(middleware.UserAuth())

View File

@ -15,6 +15,7 @@ import {
Space, Space,
Spin, Spin,
Button, Button,
Tooltip,
Input, Input,
Typography, Typography,
Select, Select,
@ -36,6 +37,8 @@ const STATUS_CODE_MAPPING_EXAMPLE = {
400: '500', 400: '500',
}; };
const fetchButtonTips = "1. 新建渠道时请求通过当前浏览器发出2. 编辑已有渠道,请求通过后端服务器发出"
function type2secretPrompt(type) { function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') // inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) { switch (type) {
@ -88,30 +91,51 @@ const EditChannel = (props) => {
const [fullModels, setFullModels] = useState([]); const [fullModels, setFullModels] = useState([]);
const [customModel, setCustomModel] = useState(''); const [customModel, setCustomModel] = useState('');
const fetchUpstreamModelList = (name) => { const fetchUpstreamModelList = async (name) => {
const url = inputs['base_url'] + '/v1/models'; if (inputs["type"] !== 1) {
const key = inputs['key'] showError("仅支持 OpenAI 接口格式")
axios.get(url, { return;
}
const models = inputs["models"] || []
let err = false;
if (isEdit) {
const res = await API.get("/api/channel/fetch_models/" + channelId)
if (res.data && res.data?.success) {
models.push(...res.data.data)
} else {
err = true
}
} else {
if (!inputs?.["key"]) {
showError("请填写密钥")
return;
}
try {
const host = new URL((inputs["base_url"] || "https://api.openai.com"))
const url = `https://${host.hostname}/v1/models`;
const key = inputs["key"];
const res = await axios.get(url, {
headers: { headers: {
'Authorization': `Bearer ${key}` 'Authorization': `Bearer ${key}`
} }
}).then((res) => { })
if (res.data && res.data?.success) { if (res.data && res.data?.success) {
const models = res.data.data.map((model) => model.id); models.push(...es.data.data.map((model) => model.id))
handleInputChange(name, models); } else {
err = true
}
}
catch (error) {
err = true
}
}
if (!err) {
handleInputChange(name, Array.from(new Set(models)));
showSuccess("获取模型列表成功"); showSuccess("获取模型列表成功");
} else { } else {
showError('获取模型列表失败'); showError('获取模型列表失败');
} }
}).catch((error) => {
console.log(error);
const errCode = error.response.status;
if (errCode === 401) {
showError(`获取模型列表失败,错误代码 ${errCode},请检查密钥是否填写`);
} else {
showError(`获取模型列表失败,错误代码 ${errCode}`);
}
})
} }
@ -575,6 +599,7 @@ const EditChannel = (props) => {
> >
填入所有模型 填入所有模型
</Button> </Button>
<Tooltip content={fetchButtonTips}>
<Button <Button
type='tertiary' type='tertiary'
onClick={() => { onClick={() => {
@ -583,6 +608,7 @@ const EditChannel = (props) => {
> >
获取模型列表 获取模型列表
</Button> </Button>
</Tooltip>
<Button <Button
type='warning' type='warning'
onClick={() => { onClick={() => {