Compare commits

..

5 Commits

Author SHA1 Message Date
jinjianming
de5faa145f Merge ade00eed1c into f9774698e9 2024-08-08 11:04:12 +08:00
jinjianmingming
ade00eed1c fix: 修复Update未能预期执行问题 2024-07-19 10:00:58 +08:00
jinjianmingming
66d31102d3 feat: 调整为定时每晚00:00
feat: 添加Redis更新缓存操作
fate: expiration_date设置默认值0
2024-07-18 22:58:04 +08:00
jinjianmingming
1aba1e4ac3 feat: 优化用户等级刷新测试,默认24h扫描一次,可以通过TIMER_FREQUENCY设置时间单位是小时 2024-07-18 10:39:21 +08:00
jinjianmingming
b86a3acd37 feat: 引入到期时间概念,当非default的分组时将有到期时间概念
fix: 修复旧用户的vip被降级,尽量不对原先用户做变动
2024-07-17 16:18:55 +08:00
45 changed files with 432 additions and 1014 deletions

View File

@@ -35,7 +35,6 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
@@ -71,13 +70,6 @@ var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -107,6 +99,7 @@ var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var TimerFrequency = env.Int("TIMER_FREQUENCY", 24) // unit is hour
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)

View File

@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}

View File

@@ -1,225 +0,0 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -17,11 +17,9 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
if err == nil {
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
}
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
} else {
userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId)

View File

@@ -81,26 +81,6 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"`
}
type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
// GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header {
h := http.Header{}
@@ -223,28 +203,6 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil
}
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
@@ -269,8 +227,6 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel)
case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel)
default:
return 0, errors.New("尚未实现")
}

View File

@@ -18,30 +18,24 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
"oidc": config.OidcEnabled,
"oidc_client_id": config.OidcClientId,
"oidc_well_known": config.OidcWellKnown,
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
"oidc_token_endpoint": config.OidcTokenEndpoint,
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
},
})
return

3
go.mod
View File

@@ -55,6 +55,7 @@ require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-co-op/gocron v1.37.0 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -87,6 +88,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
@@ -96,6 +98,7 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect

22
go.sum
View File

@@ -31,6 +31,7 @@ github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
@@ -67,6 +68,8 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -116,6 +119,7 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
@@ -156,7 +160,12 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@@ -177,6 +186,7 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
@@ -184,7 +194,12 @@ github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYde
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
@@ -198,6 +213,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
@@ -217,6 +233,8 @@ go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGX
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -266,6 +284,7 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@@ -296,7 +315,10 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -73,6 +73,8 @@ func main() {
go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(config.SyncFrequency)
}
go model.ScheduleCheckAndDowngrade()
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
@@ -112,4 +114,5 @@ func main() {
if err != nil {
logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}

View File

@@ -12,7 +12,7 @@ import (
)
type ModelRequest struct {
Model string `json:"model" form:"model"`
Model string `json:"model"`
}
func Distribute() func(c *gin.Context) {

View File

@@ -3,7 +3,6 @@ package model
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
@@ -153,11 +152,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -181,11 +176,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}

View File

@@ -1,9 +1,11 @@
package model
import (
"github.com/go-co-op/gocron"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"log"
"strconv"
"strings"
"time"
@@ -28,7 +30,6 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -79,6 +80,21 @@ func InitOptionMap() {
loadOptionsFromDatabase()
}
func ScheduleCheckAndDowngrade() {
s := gocron.NewScheduler(time.UTC)
// 设置每天0点执行
_, err := s.Every(1).Day().At("00:00").Do(checkAndDowngradeUsers)
if err != nil {
log.Fatalf("创建调度任务失败: %v", err)
}
// 开始调度
s.StartBlocking()
log.Printf("开始调度")
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
@@ -131,8 +147,6 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
@@ -179,18 +193,6 @@ func updateOptionMap(key string, value string) (err error) {
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer":
config.Footer = value
case "SystemName":

View File

@@ -30,7 +30,7 @@ type Token struct {
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"type:text"` // allowed models
Models *string `json:"models" gorm:"default:''"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
@@ -121,40 +121,30 @@ func GetTokenById(id int) (*Token, error) {
return &token, err
}
func (t *Token) Insert() error {
func (token *Token) Insert() error {
var err error
err = DB.Create(t).Error
err = DB.Create(token).Error
return err
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (t *Token) Update() error {
func (token *Token) Update() error {
var err error
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err
}
func (t *Token) SelectUpdate() error {
func (token *Token) SelectUpdate() error {
// This can update zero values
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
}
func (t *Token) Delete() error {
func (token *Token) Delete() error {
var err error
err = DB.Delete(t).Error
err = DB.Delete(token).Error
return err
}
func (t *Token) GetModels() string {
if t == nil {
return ""
}
if t.Models == nil {
return ""
}
return *t.Models
}
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {
@@ -264,14 +254,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -10,7 +10,9 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"log"
"strings"
"time"
)
const (
@@ -39,7 +41,6 @@ type User struct {
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -48,6 +49,7 @@ type User struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
ExpirationDate int64 `json:"expiration_date" gorm:"bigint;default:0;column:expiration_date"` // Expiration date of the user's subscription or account.
}
func GetMaxUserId() int {
@@ -211,6 +213,25 @@ func (user *User) ValidateAndFill() (err error) {
if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
// 校验用户是不是非default,如果是非default,判断到期时间如果过期了降级为default
if !(user.ExpirationDate > 0 && user.Username == "root") {
// 将时间戳转换为 time.Time 类型
expirationTime := time.Unix(user.ExpirationDate, 0)
// 获取当前时间
currentTime := time.Now()
// 比较当前时间和到期时间
if expirationTime.Before(currentTime) {
// 降级为 default
user.Group = "default"
err := DB.Model(user).Updates(user).Error
if err != nil {
fmt.Printf("用户: %s, 降级为 default 时发生错误: %v\n", user.Username, err)
return err
}
fmt.Printf("用户: %s, 特权组过期降为 default\n", user.Username)
}
}
return nil
}
@@ -246,14 +267,6 @@ func (user *User) FillUserByLarkId() error {
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -286,10 +299,6 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -448,3 +457,48 @@ func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username
}
func checkAndDowngradeUsers() {
// 获取昨天的时间戳
yesterdayTimestamp := time.Now().AddDate(0, 0, -1).Unix()
// 获取需要降级的用户ID列表
var userList []int
query := DB.Model(&User{}).
Where("`Group` != ?", "default").
Where("`username` != ?", "root").
Where("`expiration_date` > 0").
Where("`expiration_date` <= ?", yesterdayTimestamp).
Select("id").
Find(&userList)
// 处理查询错误
if query.Error != nil {
log.Printf("查询用户列表失败: %v", query.Error)
return
}
// 如果没有用户需要降级,直接返回
if len(userList) == 0 {
return
}
// 批量降级用户
updateQuery := DB.Model(&User{}).Where("id IN ?", userList).Update("Group", "default")
// 处理更新错误
if updateQuery.Error != nil {
log.Printf("批量更新用户分组失败: %v", updateQuery.Error)
return
}
// 删除已过期用户的Redis缓存
if common.RedisEnabled {
for _, userId := range userList {
err := common.RedisSet(fmt.Sprintf("user_group:%d", userId), "default", time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
log.Printf("更新用户: %d, 权益缓存失败, Error: %v", userId, err)
}
}
}
}

View File

@@ -1,11 +1,10 @@
package monitor
import (
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@@ -19,23 +18,31 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
return true
}
switch err.Type {
case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
lowerMessage := strings.ToLower(err.Message)
if strings.Contains(lowerMessage, "your access was terminated") ||
strings.Contains(lowerMessage, "violation of our policies") ||
strings.Contains(lowerMessage, "your credit balance is too low") ||
strings.Contains(lowerMessage, "organization has been disabled") ||
strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "已欠费") {
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false

View File

@@ -3,7 +3,6 @@ package ali
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
@@ -60,7 +59,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: request.Model,
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
@@ -103,9 +102,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
StatusCode: resp.StatusCode,
}, nil
}
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -8,8 +8,6 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",

View File

@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
// but for empty choice and no usage, we should not pass it to client, this is for azure
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
render.StringData(c, data)

View File

@@ -1,13 +1,7 @@
package stepfun
var ModelList = []string{
"step-1-8k",
"step-1-32k",
"step-1-128k",
"step-1-256k",
"step-1-flash",
"step-2-16k",
"step-1v-8k",
"step-1v-32k",
"step-1x-medium",
"step-1-200k",
}

View File

@@ -5,5 +5,4 @@ var ModelList = []string{
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
"hunyuan-vision",
}

View File

@@ -5,7 +5,6 @@ var ModelList = []string{
"SparkDesk-v1.1",
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.1-128K",
"SparkDesk-v3.5",
"SparkDesk-v4.0",
}

View File

@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
}
func parseAPIVersionByModelName(modelName string) string {
index := strings.IndexAny(modelName, "-")
if index != -1 {
return modelName[index+1:]
parts := strings.Split(modelName, "-")
if len(parts) == 2 {
return parts[1]
}
return ""
}
@@ -288,8 +288,6 @@ func apiVersion2domain(apiVersion string) string {
return "generalv2"
case "v3.1":
return "generalv3"
case "v3.1-128K":
return "pro-128k"
case "v3.5":
return "generalv3.5"
case "v4.0":
@@ -299,14 +297,7 @@ func apiVersion2domain(apiVersion string) string {
}
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
var authUrl string
domain := apiVersion2domain(apiVersion)
switch apiVersion {
case "v3.1-128K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/pro-128k", apiVersion), apiKey, apiSecret)
break
default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
}
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}

View File

@@ -30,14 +30,6 @@ var ImageSizeRatios = map[string]map[string]float64{
"720x1280": 1,
"1280x720": 1,
},
"step-1x-medium": {
"256x256": 1,
"512x512": 1,
"768x768": 1,
"1024x1024": 1,
"1280x800": 1,
"800x1280": 1,
},
}
var ImageGenerationAmounts = map[string][2]int{
@@ -47,7 +39,6 @@ var ImageGenerationAmounts = map[string][2]int{
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
"step-1x-medium": {1, 1},
}
var ImagePromptLengthLimitations = map[string]int{
@@ -57,7 +48,6 @@ var ImagePromptLengthLimitations = map[string]int{
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
"step-1x-medium": 4000,
}
var ImageOriginModelName = map[string]string{

View File

@@ -34,9 +34,7 @@ var ModelRatio = map[string]float64{
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
@@ -128,7 +126,6 @@ var ModelRatio = map[string]float64{
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
@@ -174,15 +171,10 @@ var ModelRatio = map[string]float64{
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
// https://platform.stepfun.com/docs/pricing/details
"step-1-8k": 0.005 / 1000 * RMB,
"step-1-32k": 0.015 / 1000 * RMB,
"step-1-128k": 0.040 / 1000 * RMB,
"step-1-256k": 0.095 / 1000 * RMB,
"step-1-flash": 0.001 / 1000 * RMB,
"step-2-16k": 0.038 / 1000 * RMB,
"step-1v-8k": 0.005 / 1000 * RMB,
"step-1v-32k": 0.015 / 1000 * RMB,
// stepfun todo
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
@@ -208,10 +200,8 @@ var CompletionRatio = map[string]float64{
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var (
DefaultModelRatio map[string]float64
DefaultCompletionRatio map[string]float64
)
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() {
DefaultModelRatio = make(map[string]float64)
@@ -323,7 +313,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
if strings.HasPrefix(name, "gpt-4o-mini") {
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") ||
@@ -333,9 +323,6 @@ func GetCompletionRatio(name string, channelType int) float64 {
}
return 2
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-3") {
return 5
}

View File

@@ -1,15 +1,7 @@
package model
type ResponseFormat struct {
Type string `json:"type,omitempty"`
JsonSchema *JSONSchema `json:"json_schema,omitempty"`
}
type JSONSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema map[string]interface{} `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {

View File

@@ -23,7 +23,6 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)

View File

@@ -11,14 +11,12 @@ import EditToken from '../pages/Token/EditToken';
const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
];
function renderTimestamp(timestamp) {
@@ -62,12 +60,7 @@ const TokensTable = () => {
onOpenLink('next-mj');
}
},
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
];
const columns = [
@@ -184,11 +177,6 @@ const TokensTable = () => {
node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
onOpenLink('opencat', record.key);
}
},
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
]
}
@@ -394,9 +382,6 @@ const TokensTable = () => {
case 'next-mj':
url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
if (!chatLink) {
showError('管理员未设置聊天链接');

View File

@@ -78,7 +78,7 @@ const EditChannel = (props) => {
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
break;
case 19:
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];

View File

@@ -17,6 +17,7 @@
"@tabler/icons-react": "^2.44.0",
"apexcharts": "3.35.3",
"axios": "^0.27.2",
"date-fns": "^3.6.0",
"dayjs": "^1.11.10",
"formik": "^2.2.9",
"framer-motion": "^6.3.16",
@@ -27,6 +28,7 @@
"prop-types": "^15.8.1",
"react": "^18.2.0",
"react-apexcharts": "1.4.0",
"react-datepicker": "^7.3.0",
"react-device-detect": "^2.2.2",
"react-dom": "^18.2.0",
"react-perfect-scrollbar": "^1.5.8",

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 4.3 KiB

After

Width:  |  Height:  |  Size: 5.4 KiB

View File

@@ -1,7 +0,0 @@
<svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
p-id="10969" width="200" height="200">
<path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
<path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
</svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -22,12 +22,7 @@ const config = {
turnstile_site_key: '',
version: '',
wechat_login: false,
wechat_qrcode: '',
oidc: false,
oidc_client_id: '',
oidc_authorization_endpoint: '',
oidc_token_endpoint: '',
oidc_userinfo_endpoint: '',
wechat_qrcode: ''
}
};

View File

@@ -70,28 +70,6 @@ const useLogin = () => {
}
};
const oidcLogin = async (code, state) => {
try {
const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/panel');
} else {
dispatch({ type: LOGIN, payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/panel');
}
}
return { success, message };
} catch (err) {
// 请求失败,设置错误信息
return { success: false, message: '' };
}
}
const wechatLogin = async (code) => {
try {
const res = await API.get(`/api/oauth/wechat?code=${code}`);
@@ -116,7 +94,7 @@ const useLogin = () => {
navigate('/');
};
return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin };
return { login, logout, githubLogin, wechatLogin, larkLogin };
};
export default useLogin;

View File

@@ -9,7 +9,6 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login'))
const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth')));
const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth')));
const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
const Home = Loadable(lazy(() => import('views/Home')));
@@ -54,10 +53,6 @@ const OtherRoutes = {
path: '/oauth/lark',
element: <LarkOAuth />
},
{
path: 'oauth/oidc',
element: <OidcOAuth />
},
{
path: '/404',
element: <NotFoundView />

View File

@@ -98,21 +98,6 @@ export async function onLarkOAuthClicked(lark_client_id) {
window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
}
export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState();
if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/oidc`;
const response_type = "code";
const scope = "openid profile email";
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
if (openInNewTab) {
window.open(url);
} else
{
window.location.href = url;
}
}
export function isAdmin() {
let user = localStorage.getItem('user');
if (!user) return false;

View File

@@ -1,94 +0,0 @@
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import React, { useEffect, useState } from 'react';
import { showError } from 'utils/common';
import useLogin from 'hooks/useLogin';
// material-ui
import { useTheme } from '@mui/material/styles';
import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material';
// project imports
import AuthWrapper from '../AuthWrapper';
import AuthCardWrapper from '../AuthCardWrapper';
import Logo from 'ui-component/Logo';
// assets
// ================================|| AUTH3 - LOGIN ||================================ //
const OidcOAuth = () => {
const theme = useTheme();
const matchDownSM = useMediaQuery(theme.breakpoints.down('md'));
const [searchParams] = useSearchParams();
const [prompt, setPrompt] = useState('处理中...');
const { oidcLogin } = useLogin();
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const { success, message } = await oidcLogin(code, state);
if (!success) {
if (message) {
showError(message);
}
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
await new Promise((resolve) => setTimeout(resolve, 2000));
navigate('/login');
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<AuthWrapper>
<Grid container direction="column" justifyContent="flex-end">
<Grid item xs={12}>
<Grid container justifyContent="center" alignItems="center" sx={{ minHeight: 'calc(100vh - 136px)' }}>
<Grid item sx={{ m: { xs: 1, sm: 3 }, mb: 0 }}>
<AuthCardWrapper>
<Grid container spacing={2} alignItems="center" justifyContent="center">
<Grid item sx={{ mb: 3 }}>
<Link to="#">
<Logo />
</Link>
</Grid>
<Grid item xs={12}>
<Grid container direction={matchDownSM ? 'column-reverse' : 'row'} alignItems="center" justifyContent="center">
<Grid item>
<Stack alignItems="center" justifyContent="center" spacing={1}>
<Typography color={theme.palette.primary.main} gutterBottom variant={matchDownSM ? 'h3' : 'h2'}>
OIDC 登录
</Typography>
</Stack>
</Grid>
</Grid>
</Grid>
<Grid item xs={12} container direction="column" justifyContent="center" alignItems="center" style={{ height: '200px' }}>
<CircularProgress />
<Typography variant="h3" paddingTop={'20px'}>
{prompt}
</Typography>
</Grid>
</Grid>
</AuthCardWrapper>
</Grid>
</Grid>
</Grid>
</Grid>
</AuthWrapper>
);
};
export default OidcOAuth;

View File

@@ -36,8 +36,7 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff';
import Github from 'assets/images/icons/github.svg';
import Wechat from 'assets/images/icons/wechat.svg';
import Lark from 'assets/images/icons/lark.svg';
import OIDC from 'assets/images/icons/oidc.svg';
import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common';
// ============================|| FIREBASE - LOGIN ||============================ //
@@ -51,7 +50,7 @@ const LoginForm = ({ ...others }) => {
// const [checked, setChecked] = useState(true);
let tripartiteLogin = false;
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) {
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) {
tripartiteLogin = true;
}
@@ -146,29 +145,6 @@ const LoginForm = ({ ...others }) => {
</AnimateButton>
</Grid>
)}
{siteInfo.oidc && (
<Grid item xs={12}>
<AnimateButton>
<Button
disableElevation
fullWidth
onClick={() => onOidcClicked(siteInfo.oidc_authorization_endpoint,siteInfo.oidc_client_id)}
size="large"
variant="outlined"
sx={{
color: 'grey.700',
backgroundColor: theme.palette.grey[50],
borderColor: theme.palette.grey[100]
}}
>
<Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
<img src={OIDC} alt="Lark" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
</Box>
使用 OIDC 登录
</Button>
</AnimateButton>
</Grid>
)}
<Grid item xs={12}>
<Box
sx={{

View File

@@ -268,8 +268,6 @@ function renderBalance(type, balance) {
return <span>¥{balance.toFixed(2)}</span>;
case 13: // AIGC2D
return <span>{renderNumber(balance)}</span>;
case 44: // SiliconFlow
return <span>¥{balance.toFixed(2)}</span>;
default:
return <span>不支持</span>;
}

View File

@@ -91,7 +91,7 @@ const typeConfig = {
other: '版本号'
},
input: {
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
},
prompt: {
key: '按照如下格式输入APPID|APISecret|APIKey',

View File

@@ -20,7 +20,7 @@ import SubCard from 'ui-component/cards/SubCard';
import { IconBrandWechat, IconBrandGithub, IconMail } from '@tabler/icons-react';
import Label from 'ui-component/Label';
import { API } from 'utils/api';
import { onOidcClicked, showError, showSuccess } from 'utils/common';
import { showError, showSuccess } from 'utils/common';
import { onGitHubOAuthClicked, onLarkOAuthClicked, copy } from 'utils/common';
import * as Yup from 'yup';
import WechatModal from 'views/Authentication/AuthForms/WechatModal';
@@ -28,7 +28,6 @@ import { useSelector } from 'react-redux';
import EmailModal from './component/EmailModal';
import Turnstile from 'react-turnstile';
import { ReactComponent as Lark } from 'assets/images/icons/lark.svg';
import { ReactComponent as OIDC } from 'assets/images/icons/oidc.svg';
const validationSchema = Yup.object().shape({
username: Yup.string().required('用户名 不能为空').min(3, '用户名 不能小于 3 个字符'),
@@ -124,15 +123,6 @@ export default function Profile() {
loadUser().then();
}, [status]);
function getOidcId(){
if (!inputs.oidc_id) return '';
let oidc_id = inputs.oidc_id;
if (inputs.oidc_id.length > 8) {
oidc_id = inputs.oidc_id.slice(0, 6) + '...' + inputs.oidc_id.slice(-6);
}
return oidc_id;
}
return (
<>
<UserCard>
@@ -151,9 +141,6 @@ export default function Profile() {
<Label variant="ghost" color={inputs.lark_id ? 'primary' : 'default'}>
<SvgIcon component={Lark} inheritViewBox="0 0 24 24" /> {inputs.lark_id || '未绑定'}
</Label>
<Label variant="ghost" color={inputs.oidc_id ? 'primary' : 'default'}>
<SvgIcon component={OIDC} inheritViewBox="0 0 24 24" /> {getOidcId() || '未绑定'}
</Label>
</Stack>
<SubCard title="个人信息">
<Grid container spacing={2}>
@@ -229,13 +216,6 @@ export default function Profile() {
</Button>
</Grid>
)}
{status.oidc && !inputs.oidc_id && (
<Grid xs={12} md={4}>
<Button variant="contained" onClick={() => onOidcClicked(status.oidc_authorization_endpoint,status.oidc_client_id,true)}>
绑定 OIDC 账号
</Button>
</Grid>
)}
<Grid xs={12} md={4}>
<Button
variant="contained"

View File

@@ -33,13 +33,6 @@ const SystemSetting = () => {
GitHubClientSecret: '',
LarkClientId: '',
LarkClientSecret: '',
OidcEnabled: '',
OidcWellKnown: '',
OidcClientId: '',
OidcClientSecret: '',
OidcAuthorizationEndpoint: '',
OidcTokenEndpoint: '',
OidcUserinfoEndpoint: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -101,7 +94,6 @@ const SystemSetting = () => {
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled':
case 'OidcEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
@@ -150,15 +142,8 @@ const SystemSetting = () => {
name === 'MessagePusherAddress' ||
name === 'MessagePusherToken' ||
name === 'LarkClientId' ||
name === 'LarkClientSecret' ||
name === 'OidcClientId' ||
name === 'OidcClientSecret' ||
name === 'OidcWellKnown' ||
name === 'OidcAuthorizationEndpoint' ||
name === 'OidcTokenEndpoint' ||
name === 'OidcUserinfoEndpoint'
)
{
name === 'LarkClientSecret'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
await updateOption(name, value);
@@ -240,43 +225,6 @@ const SystemSetting = () => {
}
};
const submitOidc = async () => {
if (inputs.OidcWellKnown !== '') {
if (!inputs.OidcWellKnown.startsWith('http://') && !inputs.OidcWellKnown.startsWith('https://')) {
showError('Well-Known URL 必须以 http:// 或 https:// 开头');
return;
}
try {
const res = await API.get(inputs.OidcWellKnown);
inputs.OidcAuthorizationEndpoint = res.data['authorization_endpoint'];
inputs.OidcTokenEndpoint = res.data['token_endpoint'];
inputs.OidcUserinfoEndpoint = res.data['userinfo_endpoint'];
showSuccess('获取 OIDC 配置成功!');
} catch (err) {
showError("获取 OIDC 配置失败,请检查网络状况和 Well-Known URL 是否正确");
}
}
if (originInputs['OidcWellKnown'] !== inputs.OidcWellKnown) {
await updateOption('OidcWellKnown', inputs.OidcWellKnown);
}
if (originInputs['OidcClientId'] !== inputs.OidcClientId) {
await updateOption('OidcClientId', inputs.OidcClientId);
}
if (originInputs['OidcClientSecret'] !== inputs.OidcClientSecret && inputs.OidcClientSecret !== '') {
await updateOption('OidcClientSecret', inputs.OidcClientSecret);
}
if (originInputs['OidcAuthorizationEndpoint'] !== inputs.OidcAuthorizationEndpoint) {
await updateOption('OidcAuthorizationEndpoint', inputs.OidcAuthorizationEndpoint);
}
if (originInputs['OidcTokenEndpoint'] !== inputs.OidcTokenEndpoint) {
await updateOption('OidcTokenEndpoint', inputs.OidcTokenEndpoint);
}
if (originInputs['OidcUserinfoEndpoint'] !== inputs.OidcUserinfoEndpoint) {
await updateOption('OidcUserinfoEndpoint', inputs.OidcUserinfoEndpoint);
}
};
return (
<>
<Stack spacing={2}>
@@ -343,12 +291,6 @@ const SystemSetting = () => {
control={<Checkbox checked={inputs.GitHubOAuthEnabled === 'true'} onChange={handleInputChange} name="GitHubOAuthEnabled" />}
/>
</Grid>
<Grid xs={12} md={3}>
<FormControlLabel
label="允许通过 OIDC 登录 & 注册"
control={<Checkbox checked={inputs.OidcEnabled === 'true'} onChange={handleInputChange} name="OidcEnabled" />}
/>
</Grid>
<Grid xs={12} md={3}>
<FormControlLabel
label="允许通过微信登录 & 注册"
@@ -674,117 +616,6 @@ const SystemSetting = () => {
</Grid>
</Grid>
</SubCard>
<SubCard
title="配置 OIDC"
subTitle={
<span>
用以支持通过 OIDC 登录例如 OktaAuth0 等兼容 OIDC 协议的 IdP
</span>
}
>
<Grid container spacing={ { xs: 3, sm: 2, md: 4 } }>
<Grid xs={ 12 } md={ 12 }>
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
主页链接填 <code>{ inputs.ServerAddress }</code>
重定向 URL <code>{ `${ inputs.ServerAddress }/oauth/oidc` }</code>
</Alert> <br />
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
若你的 OIDC Provider 支持 Discovery Endpoint你可以仅填写 OIDC Well-Known URL系统会自动获取 OIDC 配置
</Alert>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcClientId">Client ID</InputLabel>
<OutlinedInput
id="OidcClientId"
name="OidcClientId"
value={ inputs.OidcClientId || '' }
onChange={ handleInputChange }
label="Client ID"
placeholder="输入 OIDC 的 Client ID"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcClientSecret">Client Secret</InputLabel>
<OutlinedInput
id="OidcClientSecret"
name="OidcClientSecret"
value={ inputs.OidcClientSecret || '' }
onChange={ handleInputChange }
label="Client Secret"
placeholder="敏感信息不会发送到前端显示"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcWellKnown">Well-Known URL</InputLabel>
<OutlinedInput
id="OidcWellKnown"
name="OidcWellKnown"
value={ inputs.OidcWellKnown || '' }
onChange={ handleInputChange }
label="Well-Known URL"
placeholder="请输入 OIDC 的 Well-Known URL"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcAuthorizationEndpoint">Authorization Endpoint</InputLabel>
<OutlinedInput
id="OidcAuthorizationEndpoint"
name="OidcAuthorizationEndpoint"
value={ inputs.OidcAuthorizationEndpoint || '' }
onChange={ handleInputChange }
label="Authorization Endpoint"
placeholder="输入 OIDC 的 Authorization Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcTokenEndpoint">Token Endpoint</InputLabel>
<OutlinedInput
id="OidcTokenEndpoint"
name="OidcTokenEndpoint"
value={ inputs.OidcTokenEndpoint || '' }
onChange={ handleInputChange }
label="Token Endpoint"
placeholder="输入 OIDC 的 Token Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcUserinfoEndpoint">Userinfo Endpoint</InputLabel>
<OutlinedInput
id="OidcUserinfoEndpoint"
name="OidcUserinfoEndpoint"
value={ inputs.OidcUserinfoEndpoint || '' }
onChange={ handleInputChange }
label="Userinfo Endpoint"
placeholder="输入 OIDC 的 Userinfo Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 }>
<Button variant="contained" onClick={ submitOidc }>
保存 OIDC 设置
</Button>
</Grid>
</Grid>
</SubCard>
<SubCard
title="配置 Message Pusher"
subTitle={

View File

@@ -32,8 +32,7 @@ const COPY_OPTIONS = [
encode: false
},
{ key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true },
{ key: 'lobechat', text: 'LobeChat', url: 'https://lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"user-key","baseURL":"https://your-proxy.com/v1"}}}', encode: true }
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
];
function replacePlaceholders(text, key, serverAddress) {

View File

@@ -2,7 +2,9 @@ import PropTypes from 'prop-types';
import * as Yup from 'yup';
import { Formik } from 'formik';
import { useTheme } from '@mui/material/styles';
import { useState, useEffect } from 'react';
import { useState, useEffect, useCallback } from 'react';
import { format } from 'date-fns';
import {
Dialog,
DialogTitle,
@@ -17,7 +19,11 @@ import {
Select,
MenuItem,
IconButton,
FormHelperText
FormHelperText,
TextField,
Typography,
Switch,
FormControlLabel
} from '@mui/material';
import Visibility from '@mui/icons-material/Visibility';
@@ -44,6 +50,17 @@ const validationSchema = Yup.object().shape({
is: false,
then: Yup.number().min(0, '额度 不能小于 0'),
otherwise: Yup.number()
}),
expiration_date: Yup.mixed().when('group', {
is: (group) => group !== 'default',
then: Yup.mixed().test(
'expiration_date-required',
'到期时间 不能为空',
function (value) {
const { expiration_date } = this.parent;
return expiration_date === -1 || !!expiration_date;
}
),
})
});
@@ -53,7 +70,8 @@ const originInputs = {
display_name: '',
password: '',
group: 'default',
quota: 0
quota: 0,
expiration_date: null
};
const EditModal = ({ open, userId, onCancel, onOk }) => {
@@ -65,6 +83,12 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
// 将到期时间转换为 Unix 时间戳
if (values.expiration_date && values.expiration_date !== -1) {
const date = new Date(values.expiration_date);
values.expiration_date = Math.floor(date.getTime() / 1000); // 转换为秒级的 Unix 时间戳
}
let res;
if (values.is_edit) {
res = await API.put(`/api/user/`, { ...values, id: parseInt(userId) });
@@ -95,16 +119,23 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
event.preventDefault();
};
const loadUser = async () => {
const loadUser = useCallback(async () => {
let res = await API.get(`/api/user/${userId}`);
const { success, message, data } = res.data;
if (success) {
data.is_edit = true;
// 将 Unix 时间戳转换为日期字符串
if (data.expiration_date && data.expiration_date !== -1) {
const date = new Date(data.expiration_date * 1000); // 转换为毫秒级的时间戳
data.expiration_date = format(date, 'yyyy-MM-dd'); // 格式化为 date 格式
}
setInputs(data);
} else {
showError(message);
}
};
}, [userId]);
const fetchGroups = async () => {
try {
@@ -122,159 +153,203 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
} else {
setInputs(originInputs);
}
}, [userId]);
}, [userId, loadUser]);
return (
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
{userId ? '编辑用户' : '新建用户'}
</DialogTitle>
<Divider />
<DialogContent>
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
{({ errors, handleBlur, handleChange, handleSubmit, touched, values, isSubmitting }) => (
<form noValidate onSubmit={handleSubmit}>
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
<OutlinedInput
id="channel-username-label"
label="用户名"
type="text"
value={values.username}
name="username"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'username' }}
aria-describedby="helper-text-channel-username-label"
/>
{touched.username && errors.username && (
<FormHelperText error id="helper-tex-channel-username-label">
{errors.username}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
<OutlinedInput
id="channel-display_name-label"
label="显示名称"
type="text"
value={values.display_name}
name="display_name"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'display_name' }}
aria-describedby="helper-text-channel-display_name-label"
/>
{touched.display_name && errors.display_name && (
<FormHelperText error id="helper-tex-channel-display_name-label">
{errors.display_name}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
<OutlinedInput
id="channel-password-label"
label="密码"
type={showPassword ? 'text' : 'password'}
value={values.password}
name="password"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'password' }}
endAdornment={
<InputAdornment position="end">
<IconButton
aria-label="toggle password visibility"
onClick={handleClickShowPassword}
onMouseDown={handleMouseDownPassword}
edge="end"
size="large"
>
{showPassword ? <Visibility /> : <VisibilityOff />}
</IconButton>
</InputAdornment>
}
aria-describedby="helper-text-channel-password-label"
/>
{touched.password && errors.password && (
<FormHelperText error id="helper-tex-channel-password-label">
{errors.password}
</FormHelperText>
)}
</FormControl>
{values.is_edit && (
<>
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
{userId ? '编辑用户' : '新建用户'}
</DialogTitle>
<Divider />
<DialogContent>
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
{({ errors, handleBlur, handleChange, handleSubmit, setFieldValue, touched, values, isSubmitting }) => (
<form noValidate onSubmit={handleSubmit}>
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
<OutlinedInput
id="channel-quota-label"
label="额度"
type="number"
value={values.quota}
name="quota"
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-quota-label"
disabled={values.unlimited_quota}
id="channel-username-label"
label="用户名"
type="text"
value={values.username}
name="username"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'username' }}
aria-describedby="helper-text-channel-username-label"
/>
{touched.quota && errors.quota && (
<FormHelperText error id="helper-tex-channel-quota-label">
{errors.quota}
</FormHelperText>
{touched.username && errors.username && (
<FormHelperText error id="helper-tex-channel-username-label">
{errors.username}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
<Select
id="channel-group-label"
label="分组"
value={values.group}
name="group"
onBlur={handleBlur}
onChange={handleChange}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
<OutlinedInput
id="channel-display_name-label"
label="显示名称"
type="text"
value={values.display_name}
name="display_name"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'display_name' }}
aria-describedby="helper-text-channel-display_name-label"
/>
{touched.display_name && errors.display_name && (
<FormHelperText error id="helper-tex-channel-display_name-label">
{errors.display_name}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
<OutlinedInput
id="channel-password-label"
label="密码"
type={showPassword ? 'text' : 'password'}
value={values.password}
name="password"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'password' }}
endAdornment={
<InputAdornment position="end">
<IconButton
aria-label="toggle password visibility"
onClick={handleClickShowPassword}
onMouseDown={handleMouseDownPassword}
edge="end"
size="large"
>
{showPassword ? <Visibility /> : <VisibilityOff />}
</IconButton>
</InputAdornment>
}
}}
>
{groupOptions.map((option) => {
return (
<MenuItem key={option} value={option}>
{option}
</MenuItem>
);
})}
</Select>
{touched.group && errors.group && (
<FormHelperText error id="helper-tex-channel-group-label">
{errors.group}
</FormHelperText>
aria-describedby="helper-text-channel-password-label"
/>
{touched.password && errors.password && (
<FormHelperText error id="helper-tex-channel-password-label">
{errors.password}
</FormHelperText>
)}
</FormControl>
</>
)}
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
提交
</Button>
</DialogActions>
</form>
)}
</Formik>
</DialogContent>
</Dialog>
{values.is_edit && (
<>
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
<OutlinedInput
id="channel-quota-label"
label="额度"
type="number"
value={values.quota}
name="quota"
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-quota-label"
disabled={values.unlimited_quota}
/>
{touched.quota && errors.quota && (
<FormHelperText error id="helper-tex-channel-quota-label">
{errors.quota}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
<Select
id="channel-group-label"
label="分组"
value={values.group}
name="group"
onBlur={handleBlur}
onChange={(e) => {
handleChange(e);
if (e.target.value === 'default') {
setFieldValue('expiration_date', null);
}
}}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
}
}}
>
{groupOptions.map((option) => {
return (
<MenuItem key={option} value={option}>
{option}
</MenuItem>
);
})}
</Select>
{touched.group && errors.group && (
<FormHelperText error id="helper-tex-channel-group-label">
{errors.group}
</FormHelperText>
)}
</FormControl>
</>
)}
{values.group !== 'default' && (
<FormControl fullWidth error={Boolean(touched.expiration_date && errors.expiration_date)} sx={{ ...theme.typography.otherInput }}>
<TextField
id="channel-expiration_date-label"
label="到期时间"
type="date" // 修改为 date
value={values.expiration_date}
name="expiration_date"
onBlur={handleBlur}
onChange={handleChange}
InputLabelProps={{
shrink: true
}}
inputProps={{
max: '9999-12-31' // 设置最大日期
}}
aria-describedby="helper-text-channel-expiration_date-label"
disabled={values.expiration_date === -1}
/>
{touched.expiration_date && errors.expiration_date && (
<FormHelperText error id="helper-tex-channel-expiration_date-label">
{errors.expiration_date}
</FormHelperText>
)}
<FormControlLabel
control={
<Switch
checked={values.expiration_date === -1}
onChange={(e) => setFieldValue('expiration_date', e.target.checked ? -1 : '')}
name="permanent"
color="primary"
/>
}
label="永不过期"
/>
</FormControl>
)}
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
提交
</Button>
</DialogActions>
</form>
)}
</Formik>
</DialogContent>
</Dialog>
);
};
@@ -285,4 +360,4 @@ EditModal.propTypes = {
userId: PropTypes.number,
onCancel: PropTypes.func,
onOk: PropTypes.func
};
};

View File

@@ -52,8 +52,6 @@ function renderBalance(type, balance) {
return <span>¥{balance.toFixed(2)}</span>;
case 13: // AIGC2D
return <span>{renderNumber(balance)}</span>;
case 44: // SiliconFlow
return <span>¥{balance.toFixed(2)}</span>;
default:
return <span>不支持</span>;
}

View File

@@ -10,14 +10,12 @@ const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'BotGem', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'BotGem', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
function renderTimestamp(timestamp) {
@@ -116,9 +114,6 @@ const TokensTable = () => {
case 'next':
url = nextUrl;
break;
case 'lobechat':
url = nextLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
url = `sk-${key}`;
}
@@ -158,11 +153,7 @@ const TokensTable = () => {
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
url = defaultUrl;
}