mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-21 15:36:49 +08:00
Merge branch 'main' into customer-function
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
@@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) {
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := helper.GetRandomString(12)
|
||||
state := random.GetRandomString(12)
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
200
controller/auth/lark.go
Normal file
200
controller/auth/lark.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LarkOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type LarkUser struct {
|
||||
Name string `json:"name"`
|
||||
OpenID string `json:"open_id"`
|
||||
}
|
||||
|
||||
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{
|
||||
"client_id": config.LarkClientId,
|
||||
"client_secret": config.LarkClientSecret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LarkOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||
}
|
||||
var larkUser LarkUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&larkUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &larkUser, nil
|
||||
}
|
||||
|
||||
func LarkOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LarkBind(c)
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
larkUser, err := getLarkUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LarkId: larkUser.OpenID,
|
||||
}
|
||||
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||
err := user.FillUserByLarkId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if larkUser.Name != "" {
|
||||
user.DisplayName = larkUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Lark User"
|
||||
}
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LarkBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
larkUser, err := getLarkUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LarkId: larkUser.OpenID,
|
||||
}
|
||||
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该飞书账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LarkId = larkUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
package controller
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -83,8 +84,8 @@ func WeChatAuth(c *gin.Context) {
|
||||
if config.RegisterEnabled {
|
||||
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = "WeChat User"
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -102,14 +103,14 @@ func WeChatAuth(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func WeChatBind(c *gin.Context) {
|
||||
@@ -136,7 +137,7 @@ func WeChatBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user := model.User{
|
||||
Id: id,
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
@@ -14,13 +15,13 @@ func GetSubscription(c *gin.Context) {
|
||||
var token *model.Token
|
||||
var expiredTime int64
|
||||
if config.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
expiredTime = token.ExpiredTime
|
||||
remainQuota = token.RemainQuota
|
||||
usedQuota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
remainQuota, err = model.GetUserQuota(userId)
|
||||
if err != nil {
|
||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||
@@ -64,11 +65,11 @@ func GetUsage(c *gin.Context) {
|
||||
var err error
|
||||
var token *model.Token
|
||||
if config.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
quota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
quota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/client"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -96,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := util.HTTPClient.Do(req)
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -204,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := channeltype.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case channeltype.OpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case common.ChannelTypeAzure:
|
||||
case channeltype.Azure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
case channeltype.Custom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
case common.ChannelTypeCloseAI:
|
||||
case channeltype.CloseAI:
|
||||
return updateChannelCloseAIBalance(channel)
|
||||
case common.ChannelTypeOpenAISB:
|
||||
case channeltype.OpenAISB:
|
||||
return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
case channeltype.AIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
case channeltype.API2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case common.ChannelTypeAIGC2D:
|
||||
case channeltype.AIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
@@ -301,11 +301,11 @@ func updateAllChannelsBalance() error {
|
||||
return err
|
||||
}
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
if channel.Status != model.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
// TODO: support Azure
|
||||
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||
if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
|
||||
continue
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
|
||||
@@ -5,17 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/message"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -25,6 +14,20 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/message"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
relay "github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/controller"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -53,27 +56,37 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set(ctxkey.Channel, channel.Type)
|
||||
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||
cfg, _ := channel.LoadConfig()
|
||||
c.Set(ctxkey.Config, cfg)
|
||||
middleware.SetupContextForSelectedChannel(c, channel, "")
|
||||
meta := util.GetRelayMeta(c)
|
||||
apiType := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := helper.GetAdaptor(apiType)
|
||||
meta := meta.GetByContext(c)
|
||||
apiType := channeltype.ToAPIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
modelName := adaptor.GetModelList()[0]
|
||||
if !strings.Contains(channel.Models, modelName) {
|
||||
var modelName string
|
||||
modelList := adaptor.GetModelList()
|
||||
modelMap := channel.GetModelMapping()
|
||||
if len(modelList) != 0 {
|
||||
modelName = modelList[0]
|
||||
}
|
||||
if modelName == "" || !strings.Contains(channel.Models, modelName) {
|
||||
modelNames := strings.Split(channel.Models, ",")
|
||||
if len(modelNames) > 0 {
|
||||
modelName = modelNames[0]
|
||||
}
|
||||
if modelMap != nil && modelMap[modelName] != "" {
|
||||
modelName = modelMap[modelName]
|
||||
}
|
||||
}
|
||||
request := buildTestRequest()
|
||||
request.Model = modelName
|
||||
meta.OriginModelName, meta.ActualModelName = modelName, modelName
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -81,14 +94,15 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
logger.SysLog(string(jsonData))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := util.RelayErrorHandler(resp)
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := controller.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
@@ -171,7 +185,7 @@ func testChannels(notify bool, scope string) error {
|
||||
}
|
||||
go func() {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel)
|
||||
tok := time.Now()
|
||||
@@ -184,10 +198,10 @@ func testChannels(notify bool, scope string) error {
|
||||
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error())
|
||||
}
|
||||
}
|
||||
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
|
||||
if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
|
||||
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
|
||||
if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
|
||||
monitor.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
@@ -197,7 +211,7 @@ func testChannels(notify bool, scope string) error {
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
if notify {
|
||||
err := message.Notify(message.ByAll, "通道测试完成", "", "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
||||
err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常")
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName := range common.GroupRatio {
|
||||
for groupName := range billingratio.GroupRatio {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -72,7 +73,7 @@ func GetUserLogs(c *gin.Context) {
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -114,7 +115,7 @@ func SearchAllLogs(c *gin.Context) {
|
||||
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -153,7 +154,7 @@ func GetLogsStat(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetLogsSelfStat(c *gin.Context) {
|
||||
username := c.GetString("username")
|
||||
username := c.GetString(ctxkey.Username)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
|
||||
@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": config.EmailVerificationEnabled,
|
||||
"github_oauth": config.GitHubOAuthEnabled,
|
||||
"github_client_id": config.GitHubClientId,
|
||||
"lark_client_id": config.LarkClientId,
|
||||
"system_name": config.SystemName,
|
||||
"logo": config.Logo,
|
||||
"footer_html": config.Footer,
|
||||
|
||||
@@ -3,13 +3,16 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
relay "github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -39,8 +42,8 @@ type OpenAIModels struct {
|
||||
Parent *string `json:"parent"`
|
||||
}
|
||||
|
||||
var openAIModels []OpenAIModels
|
||||
var openAIModelsMap map[string]OpenAIModels
|
||||
var models []OpenAIModels
|
||||
var modelsMap map[string]OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func init() {
|
||||
@@ -60,15 +63,15 @@ func init() {
|
||||
IsBlocking: false,
|
||||
})
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < apitype.Dummy; i++ {
|
||||
if i == apitype.AIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := helper.GetAdaptor(i)
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
channelName := adaptor.GetChannelName()
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
models = append(models, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
@@ -80,12 +83,12 @@ func init() {
|
||||
}
|
||||
}
|
||||
for _, channelType := range openai.CompatibleChannels {
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
if channelType == channeltype.Azure {
|
||||
continue
|
||||
}
|
||||
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
||||
for _, modelName := range channelModelList {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
models = append(models, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
@@ -96,14 +99,14 @@ func init() {
|
||||
})
|
||||
}
|
||||
}
|
||||
openAIModelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range openAIModels {
|
||||
openAIModelsMap[model.Id] = model
|
||||
modelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range models {
|
||||
modelsMap[model.Id] = model
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i < common.ChannelTypeDummy; i++ {
|
||||
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
|
||||
meta := &util.RelayMeta{
|
||||
for i := 1; i < channeltype.Dummy; i++ {
|
||||
adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
|
||||
meta := &meta.Meta{
|
||||
ChannelType: i,
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
@@ -119,16 +122,55 @@ func DashboardListModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
func ListAllModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": openAIModels,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var availableModels []string
|
||||
if c.GetString(ctxkey.AvailableModels) != "" {
|
||||
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
|
||||
} else {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
|
||||
}
|
||||
modelSet := make(map[string]bool)
|
||||
for _, availableModel := range availableModels {
|
||||
modelSet[availableModel] = true
|
||||
}
|
||||
availableOpenAIModels := make([]OpenAIModels, 0)
|
||||
for _, model := range models {
|
||||
if _, ok := modelSet[model.Id]; ok {
|
||||
modelSet[model.Id] = false
|
||||
availableOpenAIModels = append(availableOpenAIModels, model)
|
||||
}
|
||||
}
|
||||
for modelName, ok := range modelSet {
|
||||
if ok {
|
||||
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": availableOpenAIModels,
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
modelId := c.Param("model")
|
||||
if model, ok := openAIModelsMap[modelId]; ok {
|
||||
if model, ok := modelsMap[modelId]; ok {
|
||||
c.JSON(200, model)
|
||||
} else {
|
||||
Error := relaymodel.Error{
|
||||
@@ -142,3 +184,30 @@ func RetrieveModel(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserAvailableModels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
userGroup, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -106,9 +108,9 @@ func AddRedemption(c *gin.Context) {
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := helper.GetUUID()
|
||||
key := random.GetUUID()
|
||||
cleanRedemption := model.Redemption{
|
||||
UserId: c.GetInt("id"),
|
||||
UserId: c.GetInt(ctxkey.Id),
|
||||
Name: redemption.Name,
|
||||
Key: key,
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
|
||||
@@ -9,31 +9,30 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
dbmodel "github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/controller"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
var err *model.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
case relaymode.ImagesGenerations:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
case relaymode.AudioSpeech:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranslation:
|
||||
case relaymode.AudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
case relaymode.AudioTranscription:
|
||||
err = controller.RelayAudioHelper(c, relayMode)
|
||||
default:
|
||||
err = controller.RelayTextHelper(c)
|
||||
@@ -43,23 +42,23 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||
if config.DebugEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
bizErr := relay(c, relayMode)
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
bizErr := relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
monitor.Emit(channelId, true)
|
||||
return
|
||||
}
|
||||
lastFailedChannelId := channelId
|
||||
channelName := c.GetString("channel_name")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
requestId := c.GetString(logger.RequestIdKey)
|
||||
requestId := c.GetString(helper.RequestIdKey)
|
||||
retryTimes := config.RetryTimes
|
||||
if !shouldRetry(c, bizErr.StatusCode) {
|
||||
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
|
||||
@@ -68,7 +67,7 @@ func Relay(c *gin.Context) {
|
||||
for i := retryTimes; i > 0; i-- {
|
||||
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
|
||||
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -79,13 +78,13 @@ func Relay(c *gin.Context) {
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
bizErr = relay(c, relayMode)
|
||||
bizErr = relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
return
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
lastFailedChannelId = channelId
|
||||
channelName := c.GetString("channel_name")
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||
}
|
||||
if bizErr != nil {
|
||||
@@ -100,7 +99,7 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
if _, ok := c.Get("specific_channel_id"); ok {
|
||||
if _, ok := c.Get(ctxkey.SpecificChannelId); ok {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
@@ -122,7 +121,7 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
|
||||
|
||||
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
monitor.DisableChannel(channelId, channelName, err.Message)
|
||||
} else {
|
||||
monitor.Emit(channelId, false)
|
||||
|
||||
@@ -1,22 +1,28 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
|
||||
|
||||
order := c.Query("order")
|
||||
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -49,7 +55,7 @@ func GetNameByToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
func SearchTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
keyword := c.Query("keyword")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword)
|
||||
if err != nil {
|
||||
@@ -69,7 +75,7 @@ func SearchTokens(c *gin.Context) {
|
||||
|
||||
func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -94,8 +100,8 @@ func GetToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetTokenStatus(c *gin.Context) {
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -117,6 +123,19 @@ func GetTokenStatus(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func validateToken(c *gin.Context, token model.Token) error {
|
||||
if len(token.Name) > 30 {
|
||||
return fmt.Errorf("令牌名称过长")
|
||||
}
|
||||
if token.Subnet != nil && *token.Subnet != "" {
|
||||
err := network.IsValidSubnets(*token.Subnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无效的网段:%s", err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
@@ -127,22 +146,26 @@ func AddToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
err = validateToken(c, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cleanToken := model.Token{
|
||||
UserId: c.GetInt("id"),
|
||||
UserId: c.GetInt(ctxkey.Id),
|
||||
Name: token.Name,
|
||||
Key: helper.GenerateKey(),
|
||||
Key: random.GenerateKey(),
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
AccessedTime: helper.GetTimestamp(),
|
||||
ExpiredTime: token.ExpiredTime,
|
||||
RemainQuota: token.RemainQuota,
|
||||
UnlimitedQuota: token.UnlimitedQuota,
|
||||
Models: token.Models,
|
||||
Subnet: token.Subnet,
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
@@ -155,13 +178,14 @@ func AddToken(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteToken(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -178,7 +202,7 @@ func DeleteToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
func UpdateToken(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
statusOnly := c.Query("status_only")
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
@@ -189,10 +213,11 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
err = validateToken(c, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -204,15 +229,15 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if token.Status == common.TokenStatusEnabled {
|
||||
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||
if token.Status == model.TokenStatusEnabled {
|
||||
if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
||||
if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
|
||||
@@ -228,6 +253,8 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.Models = token.Models
|
||||
cleanToken.Subnet = token.Subnet
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -58,11 +59,11 @@ func Login(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func setupLogin(user *model.User, c *gin.Context) {
|
||||
func SetupLogin(user *model.User, c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
@@ -184,7 +185,10 @@ func GetAllUsers(c *gin.Context) {
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
|
||||
|
||||
order := c.DefaultQuery("order", "")
|
||||
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -192,12 +196,12 @@ func GetAllUsers(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": users,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
@@ -235,8 +239,8 @@ func GetUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
||||
myRole := c.GetInt(ctxkey.Role)
|
||||
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权获取同级或更高等级用户的信息",
|
||||
@@ -252,7 +256,7 @@ func GetUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetUserDashboard(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
now := time.Now()
|
||||
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
|
||||
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
|
||||
@@ -275,7 +279,7 @@ func GetUserDashboard(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -284,7 +288,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
user.AccessToken = helper.GetUUID()
|
||||
user.AccessToken = random.GetUUID()
|
||||
|
||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -311,7 +315,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -321,7 +325,7 @@ func GetAffCode(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
user.AffCode = helper.GetRandomString(4)
|
||||
user.AffCode = random.GetRandomString(4)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -339,7 +343,7 @@ func GetAffCode(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -384,15 +388,15 @@ func UpdateUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= originUser.Role && myRole != common.RoleRootUser {
|
||||
myRole := c.GetInt(ctxkey.Role)
|
||||
if myRole <= originUser.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
|
||||
if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
|
||||
@@ -442,7 +446,7 @@ func UpdateSelf(c *gin.Context) {
|
||||
}
|
||||
|
||||
cleanUser := model.User{
|
||||
Id: c.GetInt("id"),
|
||||
Id: c.GetInt(ctxkey.Id),
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
@@ -506,7 +510,7 @@ func DeleteSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
if user.Role == common.RoleRootUser {
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不能删除超级管理员账户",
|
||||
@@ -608,7 +612,7 @@ func ManageUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
||||
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
@@ -617,8 +621,8 @@ func ManageUser(c *gin.Context) {
|
||||
}
|
||||
switch req.Action {
|
||||
case "disable":
|
||||
user.Status = common.UserStatusDisabled
|
||||
if user.Role == common.RoleRootUser {
|
||||
user.Status = model.UserStatusDisabled
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法禁用超级管理员用户",
|
||||
@@ -626,9 +630,9 @@ func ManageUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "enable":
|
||||
user.Status = common.UserStatusEnabled
|
||||
user.Status = model.UserStatusEnabled
|
||||
case "delete":
|
||||
if user.Role == common.RoleRootUser {
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法删除超级管理员用户",
|
||||
@@ -643,37 +647,37 @@ func ManageUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "promote":
|
||||
if myRole != common.RoleRootUser {
|
||||
if myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "普通管理员用户无法提升其他用户为管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role >= common.RoleAdminUser {
|
||||
if user.Role >= model.RoleAdminUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = common.RoleAdminUser
|
||||
user.Role = model.RoleAdminUser
|
||||
case "demote":
|
||||
if user.Role == common.RoleRootUser {
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法降级超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role == common.RoleCommonUser {
|
||||
if user.Role == model.RoleCommonUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是普通用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Role = model.RoleCommonUser
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
@@ -727,7 +731,7 @@ func EmailBind(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role == common.RoleRootUser {
|
||||
if user.Role == model.RoleRootUser {
|
||||
config.RootUserEmail = email
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -767,3 +771,38 @@ func TopUp(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type adminTopUpRequest struct {
|
||||
UserId int `json:"user_id"`
|
||||
Quota int `json:"quota"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func AdminTopUp(c *gin.Context) {
|
||||
req := adminTopUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Remark == "" {
|
||||
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
|
||||
}
|
||||
model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user