// Mirror of https://github.com/linux-do/new-api.git
// (synced 2025-09-17 16:06:38 +08:00; 315 lines, 6.4 KiB, Go).
package controller
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/model"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// OpenAIModel is one entry of the "data" array in an OpenAI-style model
// listing, as decoded by FetchUpstreamModels from an upstream /v1/models call.
type OpenAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Permission holds the per-model permission objects. They are decoded
	// for completeness; this file itself only reads ID.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
|
|
|
|
type OpenAIModelsResponse struct {
|
|
Data []OpenAIModel `json:"data"`
|
|
Success bool `json:"success"`
|
|
}
|
|
|
|
// GetAllChannels responds with one page of channels.
//
// Query parameters:
//   - p: zero-based page index; a missing or non-numeric value parses as 0,
//     and negative values are clamped to 0.
//   - page_size: channels per page; a negative value falls back to
//     common.ItemsPerPage. NOTE(review): a missing or non-numeric page_size
//     parses as 0 and is forwarded as-is — confirm how model.GetAllChannels
//     treats a limit of 0.
//   - id_sort: parsed with strconv.ParseBool; anything unparsable means false.
//
// Errors are reported with HTTP 200 and "success": false, like the other
// handlers in this file.
func GetAllChannels(c *gin.Context) {
	page, _ := strconv.Atoi(c.Query("p"))
	if page < 0 {
		page = 0
	}

	size, _ := strconv.Atoi(c.Query("page_size"))
	if size < 0 {
		size = common.ItemsPerPage
	}

	sortByID, _ := strconv.ParseBool(c.Query("id_sort"))

	offset := page * size
	channels, err := model.GetAllChannels(offset, size, false, sortByID)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": channels})
}
|
|
|
|
// FetchUpstreamModels queries the upstream GET /v1/models endpoint of an
// OpenAI-type channel and responds with the list of model IDs it reports.
//
// The channel is selected by the "id" route parameter. Every outcome is
// reported with HTTP 200 and a {"success", "message"[, "data"]} JSON body,
// matching the other handlers in this file. Exactly one body is written per
// request: each failure branch returns immediately.
func FetchUpstreamModels(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// GetChannel passes false here; true presumably also loads the key needed
	// for the Authorization header below — TODO confirm in model.GetChannelById.
	channel, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if channel.Type != common.ChannelTypeOpenAI {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "仅支持 OpenAI 类型渠道",
		})
		return
	}

	// BaseURL is a pointer and may be nil or empty. In that case use the
	// official OpenAI endpoint instead of dereferencing nil or requesting a
	// relative "/v1/models". Also trim a trailing slash so the URL does not
	// contain "//v1".
	baseURL := "https://api.openai.com"
	if channel.BaseURL != nil && *channel.BaseURL != "" {
		baseURL = strings.TrimSuffix(*channel.BaseURL, "/")
	}
	url := fmt.Sprintf("%s/v1/models", baseURL)

	// NOTE(review): GetResponseBody is defined elsewhere in this package;
	// whether it rejects non-2xx statuses is not visible here.
	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// Decode into a local envelope rather than OpenAIModelsResponse.
	// OpenAI's listing ({"object":"list","data":[...]}) has no "success" key,
	// and a plain bool cannot tell "absent" from "false". With the plain bool,
	// every genuine OpenAI response would be rejected. The response is treated
	// as a failure only when the upstream explicitly sends "success": false.
	var result struct {
		Data    []OpenAIModel `json:"data"`
		Success *bool         `json:"success"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	if result.Success != nil && !*result.Success {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "上游返回错误",
		})
		return
	}

	// Non-nil so an empty upstream list serializes as [] rather than null.
	// The loop variable is not named "model" to avoid shadowing the package.
	ids := make([]string, 0, len(result.Data))
	for _, m := range result.Data {
		ids = append(ids, m.ID)
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    ids,
	})
}
|
|
|
|
// FixChannelsAbilities calls model.FixAbility and reports the count it returns
// as "data". The count is presumably the number of channels or abilities
// repaired — confirm in the model package.
func FixChannelsAbilities(c *gin.Context) {
	count, err := model.FixAbility()
	if err == nil {
		c.JSON(http.StatusOK, gin.H{
			"success": true,
			"message": "",
			"data":    count,
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": false,
		"message": err.Error(),
	})
}
|
|
|
|
// SearchChannels forwards the "keyword", "group" and "model" query parameters,
// unchanged and in that order, to model.SearchChannels and responds with the
// matches.
//
// NOTE(review): unlike GetAllChannels, this handler does not honor "id_sort".
func SearchChannels(c *gin.Context) {
	channels, err := model.SearchChannels(
		c.Query("keyword"),
		c.Query("group"),
		c.Query("model"),
	)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": channels})
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
}
|
|
|
|
// GetChannel responds with the channel whose ID is given by the "id" route
// parameter. A malformed ID or a failed lookup is reported as "success": false
// with the error text as "message".
func GetChannel(c *gin.Context) {
	respondError := func(err error) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
	}

	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		respondError(err)
		return
	}
	// FetchUpstreamModels passes true when it needs the key; false here
	// presumably leaves sensitive fields out — TODO confirm in model.GetChannelById.
	channel, err := model.GetChannelById(id, false)
	if err != nil {
		respondError(err)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
}
|
|
|
|
// AddChannel creates one channel per key listed in the submitted channel.
//
// The JSON body is bound into a model.Channel. Its Key field may hold several
// keys, one per line. Each non-blank line becomes its own channel; all other
// submitted fields and the creation timestamp are copied from the request.
// Errors are reported with HTTP 200 and "success": false, like the other
// handlers in this file.
func AddChannel(c *gin.Context) {
	channel := model.Channel{}
	if err := c.ShouldBindJSON(&channel); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	channel.CreatedTime = common.GetTimestamp()

	keys := strings.Split(channel.Key, "\n")
	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		// Trim so CRLF input leaves no trailing "\r", and stray spaces are not
		// stored inside the key. Lines that are blank after trimming are skipped.
		key = strings.TrimSpace(key)
		if key == "" {
			continue
		}
		// Struct copy: each channel gets its own Key; other fields are shared
		// values from the request (pointer fields point at the same data).
		localChannel := channel
		localChannel.Key = key
		channels = append(channels, localChannel)
	}
	// Report a missing key explicitly instead of handing an empty slice to
	// BatchInsertChannels.
	if len(channels) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "渠道密钥不能为空",
		})
		return
	}

	if err := model.BatchInsertChannels(channels); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
|
|
|
|
// DeleteChannel deletes the channel whose ID is given by the "id" route
// parameter by calling Delete on a model.Channel carrying only that ID.
//
// A malformed ID is rejected up front, as GetChannel does. Previously the
// parse error was discarded, so the delete was attempted with ID 0.
func DeleteChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	channel := model.Channel{Id: id}
	if err := channel.Delete(); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
|
|
|
|
// DeleteDisabledChannel calls model.DeleteDisabledChannel and reports its
// first return value as "data". That value is presumably the number of deleted
// rows — confirm in the model package.
func DeleteDisabledChannel(c *gin.Context) {
	deleted, err := model.DeleteDisabledChannel()
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": deleted})
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
}
|
|
|
|
// ChannelBatch is the JSON request body accepted by DeleteChannelBatch: the
// channel IDs to operate on, under the "ids" key.
type ChannelBatch struct {
	Ids []int `json:"ids"`
}
|
|
|
|
func DeleteChannelBatch(c *gin.Context) {
|
|
channelBatch := ChannelBatch{}
|
|
err := c.ShouldBindJSON(&channelBatch)
|
|
if err != nil || len(channelBatch.Ids) == 0 {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "参数错误",
|
|
})
|
|
return
|
|
}
|
|
err = model.BatchDeleteChannels(channelBatch.Ids)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": len(channelBatch.Ids),
|
|
})
|
|
return
|
|
}
|
|
|
|
func UpdateChannel(c *gin.Context) {
|
|
channel := model.Channel{}
|
|
err := c.ShouldBindJSON(&channel)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
err = channel.Update()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": channel,
|
|
})
|
|
return
|
|
}
|