mirror of
https://github.com/songquanpeng/one-api.git
synced 2026-01-13 10:25:58 +08:00
Compare commits
30 Commits
de5faa145f
...
v0.6.10-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b2dc2c733 | ||
|
|
a3d7df7f89 | ||
|
|
c368232f50 | ||
|
|
cbfc983dc3 | ||
|
|
8ec092ba44 | ||
|
|
b0b88a79ff | ||
|
|
7e51b04221 | ||
|
|
f75a17f8eb | ||
|
|
6f13a3bb3c | ||
|
|
f092eed1db | ||
|
|
629378691b | ||
|
|
3716e1b0e6 | ||
|
|
a4d6e7a886 | ||
|
|
cb772e5d06 | ||
|
|
e32cb0b844 | ||
|
|
fdd7bf41c0 | ||
|
|
29389ed44f | ||
|
|
88acc5a614 | ||
|
|
a21681096a | ||
|
|
32f90a79a8 | ||
|
|
99c8c77504 | ||
|
|
649ecbf29c | ||
|
|
3a27c90910 | ||
|
|
cba82404ae | ||
|
|
c9ac670ba1 | ||
|
|
15f815c23c | ||
|
|
89b63ca96f | ||
|
|
8cc54489b9 | ||
|
|
58bf60805e | ||
|
|
6714cf96d6 |
@@ -90,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
+ [x] [together.ai](https://www.together.ai/)
|
||||
+ [x] [novita.ai](https://www.novita.ai/)
|
||||
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
|
||||
+ [x] [xAI](https://x.ai/)
|
||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
|
||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||
|
||||
@@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var OidcEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
@@ -70,6 +71,13 @@ var GitHubClientSecret = ""
|
||||
var LarkClientId = ""
|
||||
var LarkClientSecret = ""
|
||||
|
||||
var OidcClientId = ""
|
||||
var OidcClientSecret = ""
|
||||
var OidcWellKnown = ""
|
||||
var OidcAuthorizationEndpoint = ""
|
||||
var OidcTokenEndpoint = ""
|
||||
var OidcUserinfoEndpoint = ""
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
@@ -99,7 +107,6 @@ var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var TimerFrequency = env.Int("TIMER_FREQUENCY", 24) // unit is hour
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
err = c.ShouldBind(&v)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -137,3 +137,23 @@ func String2Int(str string) int {
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func Float64PtrMax(p *float64, maxValue float64) *float64 {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if *p > maxValue {
|
||||
return &maxValue
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func Float64PtrMin(p *float64, minValue float64) *float64 {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if *p < minValue {
|
||||
return &minValue
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
225
controller/auth/oidc.go
Normal file
225
controller/auth/oidc.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{
|
||||
"client_id": config.OidcClientId,
|
||||
"client_secret": config.OidcClientSecret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !config.OidcEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !config.OidcEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
|
||||
if config.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
expiredTime = token.ExpiredTime
|
||||
remainQuota = token.RemainQuota
|
||||
usedQuota = token.UsedQuota
|
||||
if err == nil {
|
||||
expiredTime = token.ExpiredTime
|
||||
remainQuota = token.RemainQuota
|
||||
usedQuota = token.UsedQuota
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
remainQuota, err = model.GetUserQuota(userId)
|
||||
|
||||
@@ -81,6 +81,26 @@ type APGC2DGPTUsageResponse struct {
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
}
|
||||
|
||||
type SiliconFlowUsageResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status bool `json:"status"`
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Image string `json:"image"`
|
||||
Email string `json:"email"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Balance string `json:"balance"`
|
||||
Status string `json:"status"`
|
||||
Introduction string `json:"introduction"`
|
||||
Role string `json:"role"`
|
||||
ChargeBalance string `json:"chargeBalance"`
|
||||
TotalBalance string `json:"totalBalance"`
|
||||
Category string `json:"category"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// GetAuthHeader get auth header
|
||||
func GetAuthHeader(token string) http.Header {
|
||||
h := http.Header{}
|
||||
@@ -203,6 +223,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
||||
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.siliconflow.cn/v1/user/info"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := SiliconFlowUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if response.Code != 20000 {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.Balance, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := channeltype.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
@@ -227,6 +269,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case channeltype.AIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case channeltype.SiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
|
||||
@@ -76,9 +76,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
|
||||
if len(modelNames) > 0 {
|
||||
modelName = modelNames[0]
|
||||
}
|
||||
if modelMap != nil && modelMap[modelName] != "" {
|
||||
modelName = modelMap[modelName]
|
||||
}
|
||||
}
|
||||
if modelMap != nil && modelMap[modelName] != "" {
|
||||
modelName = modelMap[modelName]
|
||||
}
|
||||
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
|
||||
request.Model = modelName
|
||||
|
||||
@@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) {
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": config.EmailVerificationEnabled,
|
||||
"github_oauth": config.GitHubOAuthEnabled,
|
||||
"github_client_id": config.GitHubClientId,
|
||||
"lark_client_id": config.LarkClientId,
|
||||
"system_name": config.SystemName,
|
||||
"logo": config.Logo,
|
||||
"footer_html": config.Footer,
|
||||
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": config.WeChatAuthEnabled,
|
||||
"server_address": config.ServerAddress,
|
||||
"turnstile_check": config.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": config.TurnstileSiteKey,
|
||||
"top_up_link": config.TopUpLink,
|
||||
"chat_link": config.ChatLink,
|
||||
"quota_per_unit": config.QuotaPerUnit,
|
||||
"display_in_currency": config.DisplayInCurrencyEnabled,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": config.EmailVerificationEnabled,
|
||||
"github_oauth": config.GitHubOAuthEnabled,
|
||||
"github_client_id": config.GitHubClientId,
|
||||
"lark_client_id": config.LarkClientId,
|
||||
"system_name": config.SystemName,
|
||||
"logo": config.Logo,
|
||||
"footer_html": config.Footer,
|
||||
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": config.WeChatAuthEnabled,
|
||||
"server_address": config.ServerAddress,
|
||||
"turnstile_check": config.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": config.TurnstileSiteKey,
|
||||
"top_up_link": config.TopUpLink,
|
||||
"chat_link": config.ChatLink,
|
||||
"quota_per_unit": config.QuotaPerUnit,
|
||||
"display_in_currency": config.DisplayInCurrencyEnabled,
|
||||
"oidc": config.OidcEnabled,
|
||||
"oidc_client_id": config.OidcClientId,
|
||||
"oidc_well_known": config.OidcWellKnown,
|
||||
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
|
||||
"oidc_token_endpoint": config.OidcTokenEndpoint,
|
||||
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
3
go.mod
3
go.mod
@@ -55,7 +55,6 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-co-op/gocron v1.37.0 // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -88,7 +87,6 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/smarty/assertions v1.15.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
@@ -98,7 +96,6 @@ require (
|
||||
go.opentelemetry.io/otel v1.24.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.24.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/net v0.26.0 // indirect
|
||||
golang.org/x/oauth2 v0.21.0 // indirect
|
||||
|
||||
22
go.sum
22
go.sum
@@ -31,7 +31,6 @@ github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
@@ -68,8 +67,6 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm
|
||||
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
|
||||
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
|
||||
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
|
||||
github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
|
||||
github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
|
||||
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -119,7 +116,6 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
|
||||
@@ -160,12 +156,7 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
@@ -186,7 +177,6 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
@@ -194,12 +184,7 @@ github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYde
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
|
||||
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
|
||||
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
|
||||
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
|
||||
@@ -213,7 +198,6 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
@@ -233,8 +217,6 @@ go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGX
|
||||
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
|
||||
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
|
||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
@@ -284,7 +266,6 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
|
||||
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
@@ -315,10 +296,7 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
3
main.go
3
main.go
@@ -73,8 +73,6 @@ func main() {
|
||||
go model.SyncOptions(config.SyncFrequency)
|
||||
go model.SyncChannelCache(config.SyncFrequency)
|
||||
}
|
||||
go model.ScheduleCheckAndDowngrade()
|
||||
|
||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||
if err != nil {
|
||||
@@ -114,5 +112,4 @@ func main() {
|
||||
if err != nil {
|
||||
logger.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model" form:"model"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
|
||||
13
model/log.go
13
model/log.go
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
|
||||
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
|
||||
ifnull := "ifnull"
|
||||
if common.UsingPostgreSQL {
|
||||
ifnull = "COALESCE"
|
||||
}
|
||||
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
}
|
||||
@@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
}
|
||||
|
||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
|
||||
ifnull := "ifnull"
|
||||
if common.UsingPostgreSQL {
|
||||
ifnull = "COALESCE"
|
||||
}
|
||||
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
|
||||
if username != "" {
|
||||
tx = tx.Where("username = ?", username)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/go-co-op/gocron"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -30,6 +28,7 @@ func InitOptionMap() {
|
||||
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
|
||||
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
|
||||
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
|
||||
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
|
||||
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
|
||||
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
|
||||
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
|
||||
@@ -80,21 +79,6 @@ func InitOptionMap() {
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
|
||||
func ScheduleCheckAndDowngrade() {
|
||||
s := gocron.NewScheduler(time.UTC)
|
||||
|
||||
// 设置每天0点执行
|
||||
_, err := s.Every(1).Day().At("00:00").Do(checkAndDowngradeUsers)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("创建调度任务失败: %v", err)
|
||||
}
|
||||
|
||||
// 开始调度
|
||||
s.StartBlocking()
|
||||
log.Printf("开始调度")
|
||||
}
|
||||
|
||||
func loadOptionsFromDatabase() {
|
||||
options, _ := AllOption()
|
||||
for _, option := range options {
|
||||
@@ -147,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
config.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
config.GitHubOAuthEnabled = boolValue
|
||||
case "OidcEnabled":
|
||||
config.OidcEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
config.WeChatAuthEnabled = boolValue
|
||||
case "TurnstileCheckEnabled":
|
||||
@@ -193,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
config.LarkClientId = value
|
||||
case "LarkClientSecret":
|
||||
config.LarkClientSecret = value
|
||||
case "OidcClientId":
|
||||
config.OidcClientId = value
|
||||
case "OidcClientSecret":
|
||||
config.OidcClientSecret = value
|
||||
case "OidcWellKnown":
|
||||
config.OidcWellKnown = value
|
||||
case "OidcAuthorizationEndpoint":
|
||||
config.OidcAuthorizationEndpoint = value
|
||||
case "OidcTokenEndpoint":
|
||||
config.OidcTokenEndpoint = value
|
||||
case "OidcUserinfoEndpoint":
|
||||
config.OidcUserinfoEndpoint = value
|
||||
case "Footer":
|
||||
config.Footer = value
|
||||
case "SystemName":
|
||||
|
||||
@@ -30,7 +30,7 @@ type Token struct {
|
||||
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
|
||||
Models *string `json:"models" gorm:"default:''"` // allowed models
|
||||
Models *string `json:"models" gorm:"type:text"` // allowed models
|
||||
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
|
||||
}
|
||||
|
||||
@@ -121,30 +121,40 @@ func GetTokenById(id int) (*Token, error) {
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func (token *Token) Insert() error {
|
||||
func (t *Token) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(token).Error
|
||||
err = DB.Create(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
func (t *Token) Update() error {
|
||||
var err error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
|
||||
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (token *Token) SelectUpdate() error {
|
||||
func (t *Token) SelectUpdate() error {
|
||||
// This can update zero values
|
||||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
||||
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error
|
||||
}
|
||||
|
||||
func (token *Token) Delete() error {
|
||||
func (t *Token) Delete() error {
|
||||
var err error
|
||||
err = DB.Delete(token).Error
|
||||
err = DB.Delete(t).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Token) GetModels() string {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
if t.Models == nil {
|
||||
return ""
|
||||
}
|
||||
return *t.Models
|
||||
}
|
||||
|
||||
func DeleteTokenById(id int, userId int) (err error) {
|
||||
// Why we need userId here? In case user want to delete other's token.
|
||||
if id == 0 || userId == 0 {
|
||||
@@ -254,14 +264,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||
|
||||
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
|
||||
token, err := GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if quota > 0 {
|
||||
err = DecreaseUserQuota(token.UserId, quota)
|
||||
} else {
|
||||
err = IncreaseUserQuota(token.UserId, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(tokenId, quota)
|
||||
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -41,6 +39,7 @@ type User struct {
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int64 `json:"quota" gorm:"bigint;default:0"`
|
||||
@@ -49,7 +48,6 @@ type User struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
ExpirationDate int64 `json:"expiration_date" gorm:"bigint;default:0;column:expiration_date"` // Expiration date of the user's subscription or account.
|
||||
}
|
||||
|
||||
func GetMaxUserId() int {
|
||||
@@ -213,25 +211,6 @@ func (user *User) ValidateAndFill() (err error) {
|
||||
if !okay || user.Status != UserStatusEnabled {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
}
|
||||
// 校验用户是不是非default,如果是非default,判断到期时间如果过期了降级为default
|
||||
if !(user.ExpirationDate > 0 && user.Username == "root") {
|
||||
// 将时间戳转换为 time.Time 类型
|
||||
expirationTime := time.Unix(user.ExpirationDate, 0)
|
||||
// 获取当前时间
|
||||
currentTime := time.Now()
|
||||
|
||||
// 比较当前时间和到期时间
|
||||
if expirationTime.Before(currentTime) {
|
||||
// 降级为 default
|
||||
user.Group = "default"
|
||||
err := DB.Model(user).Updates(user).Error
|
||||
if err != nil {
|
||||
fmt.Printf("用户: %s, 降级为 default 时发生错误: %v\n", user.Username, err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("用户: %s, 特权组过期降为 default\n", user.Username)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -267,6 +246,14 @@ func (user *User) FillUserByLarkId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
}
|
||||
DB.Where(User{OidcId: user.OidcId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -299,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -457,48 +448,3 @@ func GetUsernameById(id int) (username string) {
|
||||
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
||||
return username
|
||||
}
|
||||
|
||||
func checkAndDowngradeUsers() {
|
||||
// 获取昨天的时间戳
|
||||
yesterdayTimestamp := time.Now().AddDate(0, 0, -1).Unix()
|
||||
|
||||
// 获取需要降级的用户ID列表
|
||||
var userList []int
|
||||
query := DB.Model(&User{}).
|
||||
Where("`Group` != ?", "default").
|
||||
Where("`username` != ?", "root").
|
||||
Where("`expiration_date` > 0").
|
||||
Where("`expiration_date` <= ?", yesterdayTimestamp).
|
||||
Select("id").
|
||||
Find(&userList)
|
||||
|
||||
// 处理查询错误
|
||||
if query.Error != nil {
|
||||
log.Printf("查询用户列表失败: %v", query.Error)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果没有用户需要降级,直接返回
|
||||
if len(userList) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 批量降级用户
|
||||
updateQuery := DB.Model(&User{}).Where("id IN ?", userList).Update("Group", "default")
|
||||
|
||||
// 处理更新错误
|
||||
if updateQuery.Error != nil {
|
||||
log.Printf("批量更新用户分组失败: %v", updateQuery.Error)
|
||||
return
|
||||
}
|
||||
|
||||
// 删除已过期用户的Redis缓存
|
||||
if common.RedisEnabled {
|
||||
for _, userId := range userList {
|
||||
err := common.RedisSet(fmt.Sprintf("user_group:%d", userId), "default", time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("更新用户: %d, 权益缓存失败, Error: %v", userId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
|
||||
@@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
|
||||
return true
|
||||
}
|
||||
switch err.Type {
|
||||
case "insufficient_quota":
|
||||
return true
|
||||
// https://docs.anthropic.com/claude/reference/errors
|
||||
case "authentication_error":
|
||||
return true
|
||||
case "permission_error":
|
||||
return true
|
||||
case "forbidden":
|
||||
case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
|
||||
return true
|
||||
}
|
||||
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||
return true
|
||||
}
|
||||
//if strings.Contains(err.Message, "quota") {
|
||||
// return true
|
||||
//}
|
||||
if strings.Contains(err.Message, "credit") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(err.Message, "balance") {
|
||||
|
||||
lowerMessage := strings.ToLower(err.Message)
|
||||
if strings.Contains(lowerMessage, "your access was terminated") ||
|
||||
strings.Contains(lowerMessage, "violation of our policies") ||
|
||||
strings.Contains(lowerMessage, "your credit balance is too low") ||
|
||||
strings.Contains(lowerMessage, "organization has been disabled") ||
|
||||
strings.Contains(lowerMessage, "credit") ||
|
||||
strings.Contains(lowerMessage, "balance") ||
|
||||
strings.Contains(lowerMessage, "permission denied") ||
|
||||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
|
||||
strings.Contains(lowerMessage, "已欠费") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,7 @@ package ali
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/render"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
enableSearch = true
|
||||
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.9999
|
||||
}
|
||||
request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
|
||||
return &ChatRequest{
|
||||
Model: aliModel,
|
||||
Input: Input{
|
||||
@@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
|
||||
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Model: request.Model,
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
@@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
requestModel := c.GetString(ctxkey.RequestModel)
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
fullTextResponse.Model = requestModel
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -16,13 +16,13 @@ type Input struct {
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
ResultFormat string `json:"result_format,omitempty"`
|
||||
Tools []model.Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
@@ -3,7 +3,11 @@ package anthropic
|
||||
var ModelList = []string{
|
||||
"claude-instant-1.2", "claude-2.0", "claude-2.1",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-5-haiku-20241022",
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ type Request struct {
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
|
||||
@@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
|
||||
@@ -11,8 +11,8 @@ type Request struct {
|
||||
Messages []anthropic.Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropic.Tool `json:"tools,omitempty"`
|
||||
|
||||
@@ -4,10 +4,10 @@ package aws
|
||||
//
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
type Request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
MaxGenLen int `json:"max_gen_len,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxGenLen int `json:"max_gen_len,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
}
|
||||
|
||||
// Response is the response from AWS Llama3
|
||||
|
||||
@@ -35,9 +35,9 @@ type Message struct {
|
||||
|
||||
type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
PenaltyScore *float64 `json:"penalty_score,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
DisableSearch bool `json:"disable_search,omitempty"`
|
||||
|
||||
@@ -9,5 +9,5 @@ type Request struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
||||
K: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
FrequencyPenalty: textRequest.FrequencyPenalty,
|
||||
PresencePenalty: textRequest.FrequencyPenalty,
|
||||
PresencePenalty: textRequest.PresencePenalty,
|
||||
Seed: int(textRequest.Seed),
|
||||
}
|
||||
if cohereRequest.Model == "" {
|
||||
|
||||
@@ -10,15 +10,15 @@ type Request struct {
|
||||
PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
|
||||
Connectors []Connector `json:"connectors,omitempty"`
|
||||
Documents []Document `json:"documents,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3
|
||||
Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxInputTokens int `json:"max_input_tokens,omitempty"`
|
||||
K int `json:"k,omitempty"` // 默认值为0
|
||||
P float64 `json:"p,omitempty"` // 默认值为0.75
|
||||
P *float64 `json:"p,omitempty"` // 默认值为0.75
|
||||
Seed int `json:"seed,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolResults []ToolResult `json:"tool_results,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/render"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/render"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@@ -28,6 +29,11 @@ const (
|
||||
VisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
var mimeTypeMap = map[string]string{
|
||||
"json_object": "application/json",
|
||||
"text": "text/plain",
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
geminiRequest := ChatRequest{
|
||||
@@ -56,6 +62,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.ResponseFormat != nil {
|
||||
if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = mimeType
|
||||
}
|
||||
if textRequest.ResponseFormat.JsonSchema != nil {
|
||||
geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
|
||||
}
|
||||
}
|
||||
if textRequest.Tools != nil {
|
||||
functions := make([]model.Function, 0, len(textRequest.Tools))
|
||||
for _, tool := range textRequest.Tools {
|
||||
|
||||
@@ -65,10 +65,12 @@ type ChatTools struct {
|
||||
}
|
||||
|
||||
type ChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
@@ -4,14 +4,24 @@ package groq
|
||||
|
||||
var ModelList = []string{
|
||||
"gemma-7b-it",
|
||||
"mixtral-8x7b-32768",
|
||||
"llama3-8b-8192",
|
||||
"llama3-70b-8192",
|
||||
"gemma2-9b-it",
|
||||
"llama-3.1-405b-reasoning",
|
||||
"llama-3.1-70b-versatile",
|
||||
"llama-3.1-8b-instant",
|
||||
"llama-3.2-11b-text-preview",
|
||||
"llama-3.2-11b-vision-preview",
|
||||
"llama-3.2-1b-preview",
|
||||
"llama-3.2-3b-preview",
|
||||
"llama-3.2-11b-vision-preview",
|
||||
"llama-3.2-90b-text-preview",
|
||||
"llama-3.2-90b-vision-preview",
|
||||
"llama-guard-3-8b",
|
||||
"llama3-70b-8192",
|
||||
"llama3-8b-8192",
|
||||
"llama3-groq-70b-8192-tool-use-preview",
|
||||
"llama3-groq-8b-8192-tool-use-preview",
|
||||
"llava-v1.5-7b-4096-preview",
|
||||
"mixtral-8x7b-32768",
|
||||
"distil-whisper-large-v3-en",
|
||||
"whisper-large-v3",
|
||||
"whisper-large-v3-turbo",
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package ollama
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
||||
@@ -75,6 +75,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if request.Stream {
|
||||
// always return usage in stream mode
|
||||
if request.StreamOptions == nil {
|
||||
request.StreamOptions = &model.StreamOptions{}
|
||||
}
|
||||
request.StreamOptions.IncludeUsage = true
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/novita"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/xai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,7 @@ var CompatibleChannels = []int{
|
||||
channeltype.TogetherAI,
|
||||
channeltype.Novita,
|
||||
channeltype.SiliconFlow,
|
||||
channeltype.XAI,
|
||||
}
|
||||
|
||||
func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
||||
@@ -64,6 +66,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
|
||||
return "novita", novita.ModelList
|
||||
case channeltype.SiliconFlow:
|
||||
return "siliconflow", siliconflow.ModelList
|
||||
case channeltype.XAI:
|
||||
return "xai", xai.ModelList
|
||||
default:
|
||||
return "openai", ModelList
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ var ModelList = []string{
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4o", "gpt-4o-2024-05-13",
|
||||
"gpt-4o-2024-08-06",
|
||||
"chatgpt-4o-latest",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"gpt-4-vision-preview",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
|
||||
@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
||||
render.StringData(c, data) // if error happened, pass the data to client
|
||||
continue // just ignore the error
|
||||
}
|
||||
if len(streamResponse.Choices) == 0 {
|
||||
// but for empty choice, we should not pass it to client, this is for azure
|
||||
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
||||
// but for empty choice and no usage, we should not pass it to client, this is for azure
|
||||
continue // just ignore empty choice
|
||||
}
|
||||
render.StringData(c, data)
|
||||
|
||||
@@ -19,11 +19,11 @@ type Prompt struct {
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Prompt Prompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
Prompt Prompt `json:"prompt"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package stepfun
|
||||
|
||||
var ModelList = []string{
|
||||
"step-1-8k",
|
||||
"step-1-32k",
|
||||
"step-1-128k",
|
||||
"step-1-256k",
|
||||
"step-1-flash",
|
||||
"step-2-16k",
|
||||
"step-1v-8k",
|
||||
"step-1v-32k",
|
||||
"step-1-200k",
|
||||
"step-1x-medium",
|
||||
}
|
||||
|
||||
@@ -5,4 +5,5 @@ var ModelList = []string{
|
||||
"hunyuan-standard",
|
||||
"hunyuan-standard-256K",
|
||||
"hunyuan-pro",
|
||||
"hunyuan-vision",
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
|
||||
Model: &request.Model,
|
||||
Stream: &request.Stream,
|
||||
Messages: messages,
|
||||
TopP: &request.TopP,
|
||||
Temperature: &request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Temperature: request.Temperature,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,12 @@ import (
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
|
||||
"claude-3-haiku@20240307",
|
||||
"claude-3-sonnet@20240229",
|
||||
"claude-3-opus@20240229",
|
||||
"claude-3-5-sonnet@20240620",
|
||||
"claude-3-5-sonnet-v2@20241022",
|
||||
"claude-3-5-haiku@20241022",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
|
||||
@@ -11,8 +11,8 @@ type Request struct {
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []anthropic.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002",
|
||||
}
|
||||
|
||||
type Adaptor struct {
|
||||
|
||||
5
relay/adaptor/xai/constants.go
Normal file
5
relay/adaptor/xai/constants.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package xai
|
||||
|
||||
var ModelList = []string{
|
||||
"grok-beta",
|
||||
}
|
||||
@@ -5,6 +5,8 @@ var ModelList = []string{
|
||||
"SparkDesk-v1.1",
|
||||
"SparkDesk-v2.1",
|
||||
"SparkDesk-v3.1",
|
||||
"SparkDesk-v3.1-128K",
|
||||
"SparkDesk-v3.5",
|
||||
"SparkDesk-v3.5-32K",
|
||||
"SparkDesk-v4.0",
|
||||
}
|
||||
|
||||
@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
|
||||
}
|
||||
|
||||
func parseAPIVersionByModelName(modelName string) string {
|
||||
parts := strings.Split(modelName, "-")
|
||||
if len(parts) == 2 {
|
||||
return parts[1]
|
||||
index := strings.IndexAny(modelName, "-")
|
||||
if index != -1 {
|
||||
return modelName[index+1:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string {
|
||||
func apiVersion2domain(apiVersion string) string {
|
||||
switch apiVersion {
|
||||
case "v1.1":
|
||||
return "general"
|
||||
return "lite"
|
||||
case "v2.1":
|
||||
return "generalv2"
|
||||
case "v3.1":
|
||||
return "generalv3"
|
||||
case "v3.1-128K":
|
||||
return "pro-128k"
|
||||
case "v3.5":
|
||||
return "generalv3.5"
|
||||
case "v3.5-32K":
|
||||
return "max-32k"
|
||||
case "v4.0":
|
||||
return "4.0Ultra"
|
||||
}
|
||||
@@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string {
|
||||
}
|
||||
|
||||
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
|
||||
var authUrl string
|
||||
domain := apiVersion2domain(apiVersion)
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
switch apiVersion {
|
||||
case "v3.1-128K":
|
||||
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
|
||||
break
|
||||
case "v3.5-32K":
|
||||
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
|
||||
break
|
||||
default:
|
||||
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
}
|
||||
return domain, authUrl
|
||||
}
|
||||
|
||||
@@ -19,11 +19,11 @@ type ChatRequest struct {
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
|
||||
return baiduEmbeddingRequest, err
|
||||
default:
|
||||
// TopP (0.0, 1.0)
|
||||
request.TopP = math.Min(0.99, request.TopP)
|
||||
request.TopP = math.Max(0.01, request.TopP)
|
||||
// TopP [0.0, 1.0]
|
||||
request.TopP = helper.Float64PtrMax(request.TopP, 1)
|
||||
request.TopP = helper.Float64PtrMin(request.TopP, 0)
|
||||
|
||||
// Temperature (0.0, 1.0)
|
||||
request.Temperature = math.Min(0.99, request.Temperature)
|
||||
request.Temperature = math.Max(0.01, request.Temperature)
|
||||
// Temperature [0.0, 1.0]
|
||||
request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
|
||||
request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
|
||||
a.SetVersionByModeName(request.Model)
|
||||
if a.APIVersion == "v4" {
|
||||
return request, nil
|
||||
|
||||
@@ -12,8 +12,8 @@ type Message struct {
|
||||
|
||||
type Request struct {
|
||||
Prompt []Message `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
@@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
|
||||
"720x1280": 1,
|
||||
"1280x720": 1,
|
||||
},
|
||||
"step-1x-medium": {
|
||||
"256x256": 1,
|
||||
"512x512": 1,
|
||||
"768x768": 1,
|
||||
"1024x1024": 1,
|
||||
"1280x800": 1,
|
||||
"800x1280": 1,
|
||||
},
|
||||
}
|
||||
|
||||
var ImageGenerationAmounts = map[string][2]int{
|
||||
@@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
|
||||
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
|
||||
"wanx-v1": {1, 4}, // Ali
|
||||
"cogview-3": {1, 1},
|
||||
"step-1x-medium": {1, 1},
|
||||
}
|
||||
|
||||
var ImagePromptLengthLimitations = map[string]int{
|
||||
@@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
|
||||
"ali-stable-diffusion-v1.5": 4000,
|
||||
"wanx-v1": 4000,
|
||||
"cogview-3": 833,
|
||||
"step-1x-medium": 4000,
|
||||
}
|
||||
|
||||
var ImageOriginModelName = map[string]string{
|
||||
|
||||
@@ -34,7 +34,9 @@ var ModelRatio = map[string]float64{
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 2.5, // $0.005 / 1K tokens
|
||||
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
|
||||
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
|
||||
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
@@ -77,8 +79,10 @@ var ModelRatio = map[string]float64{
|
||||
"claude-2.0": 8.0 / 1000 * USD,
|
||||
"claude-2.1": 8.0 / 1000 * USD,
|
||||
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
|
||||
"claude-3-5-haiku-20241022": 1.0 / 1000 * USD,
|
||||
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
|
||||
"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
|
||||
"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
|
||||
"claude-3-opus-20240229": 15.0 / 1000 * USD,
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
|
||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
||||
@@ -126,7 +130,9 @@ var ModelRatio = map[string]float64{
|
||||
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
|
||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
@@ -158,23 +164,34 @@ var ModelRatio = map[string]float64{
|
||||
"mistral-embed": 0.1 / 1000 * USD,
|
||||
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
|
||||
"gemma-7b-it": 0.07 / 1000000 * USD,
|
||||
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
|
||||
"llama3-8b-8192": 0.05 / 1000000 * USD,
|
||||
"llama3-70b-8192": 0.59 / 1000000 * USD,
|
||||
"gemma2-9b-it": 0.20 / 1000000 * USD,
|
||||
"llama-3.1-405b-reasoning": 0.89 / 1000000 * USD,
|
||||
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
|
||||
"llama-3.1-8b-instant": 0.05 / 1000000 * USD,
|
||||
"llama-3.2-11b-text-preview": 0.05 / 1000000 * USD,
|
||||
"llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD,
|
||||
"llama-3.2-1b-preview": 0.05 / 1000000 * USD,
|
||||
"llama-3.2-3b-preview": 0.05 / 1000000 * USD,
|
||||
"llama-3.2-90b-text-preview": 0.59 / 1000000 * USD,
|
||||
"llama-guard-3-8b": 0.05 / 1000000 * USD,
|
||||
"llama3-70b-8192": 0.59 / 1000000 * USD,
|
||||
"llama3-8b-8192": 0.05 / 1000000 * USD,
|
||||
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
|
||||
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
|
||||
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
|
||||
|
||||
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
|
||||
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
|
||||
"yi-vl-plus": 6.0 / 1000 * RMB,
|
||||
// stepfun todo
|
||||
"step-1v-32k": 0.024 * RMB,
|
||||
"step-1-32k": 0.024 * RMB,
|
||||
"step-1-200k": 0.15 * RMB,
|
||||
// https://platform.stepfun.com/docs/pricing/details
|
||||
"step-1-8k": 0.005 / 1000 * RMB,
|
||||
"step-1-32k": 0.015 / 1000 * RMB,
|
||||
"step-1-128k": 0.040 / 1000 * RMB,
|
||||
"step-1-256k": 0.095 / 1000 * RMB,
|
||||
"step-1-flash": 0.001 / 1000 * RMB,
|
||||
"step-2-16k": 0.038 / 1000 * RMB,
|
||||
"step-1v-8k": 0.005 / 1000 * RMB,
|
||||
"step-1v-32k": 0.015 / 1000 * RMB,
|
||||
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
|
||||
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
|
||||
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
|
||||
@@ -192,6 +209,8 @@ var ModelRatio = map[string]float64{
|
||||
"deepl-zh": 25.0 / 1000 * USD,
|
||||
"deepl-en": 25.0 / 1000 * USD,
|
||||
"deepl-ja": 25.0 / 1000 * USD,
|
||||
// https://console.x.ai/
|
||||
"grok-beta": 5.0 / 1000 * USD,
|
||||
}
|
||||
|
||||
var CompletionRatio = map[string]float64{
|
||||
@@ -200,8 +219,10 @@ var CompletionRatio = map[string]float64{
|
||||
"llama3-70b-8192(33)": 0.0035 / 0.00265,
|
||||
}
|
||||
|
||||
var DefaultModelRatio map[string]float64
|
||||
var DefaultCompletionRatio map[string]float64
|
||||
var (
|
||||
DefaultModelRatio map[string]float64
|
||||
DefaultCompletionRatio map[string]float64
|
||||
)
|
||||
|
||||
func init() {
|
||||
DefaultModelRatio = make(map[string]float64)
|
||||
@@ -313,7 +334,7 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") {
|
||||
if strings.HasPrefix(name, "gpt-4o-mini") {
|
||||
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
|
||||
return 4
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-turbo") ||
|
||||
@@ -323,6 +344,9 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if name == "chatgpt-4o-latest" {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
@@ -351,6 +375,8 @@ func GetCompletionRatio(name string, channelType int) float64 {
|
||||
return 3
|
||||
case "command-r-plus":
|
||||
return 5
|
||||
case "grok-beta":
|
||||
return 3
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -46,5 +46,6 @@ const (
|
||||
VertextAI
|
||||
Proxy
|
||||
SiliconFlow
|
||||
XAI
|
||||
Dummy
|
||||
)
|
||||
|
||||
@@ -45,7 +45,8 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.novita.ai/v3/openai", // 41
|
||||
"", // 42
|
||||
"", // 43
|
||||
"https://api.siliconflow.cn", // 44
|
||||
"https://api.siliconflow.cn", // 44
|
||||
"https://api.x.ai", // 45
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
)
|
||||
|
||||
@@ -1,35 +1,70 @@
|
||||
package model
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *JSONSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type JSONSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema map[string]interface{} `json:"schema,omitempty"`
|
||||
Strict *bool `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type Audio struct {
|
||||
Voice string `json:"voice,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
}
|
||||
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
Metadata any `json:"metadata,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
LogitBias any `json:"logit_bias,omitempty"`
|
||||
Logprobs *bool `json:"logprobs,omitempty"`
|
||||
TopLogprobs *int `json:"top_logprobs,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Prediction any `json:"prediction,omitempty"`
|
||||
Audio *Audio `json:"audio,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
ServiceTier *string `json:"service_tier,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
FunctionCall any `json:"function_call,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
// https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
Input any `json:"input,omitempty"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Quality *string `json:"quality,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Style *string `json:"style,omitempty"`
|
||||
// Others
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
||||
@@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
|
||||
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
|
||||
|
||||
@@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
|
||||
const COPY_OPTIONS = [
|
||||
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
|
||||
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
|
||||
];
|
||||
|
||||
const OPEN_LINK_OPTIONS = [
|
||||
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
|
||||
];
|
||||
|
||||
function renderTimestamp(timestamp) {
|
||||
@@ -60,7 +62,12 @@ const TokensTable = () => {
|
||||
onOpenLink('next-mj');
|
||||
}
|
||||
},
|
||||
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
|
||||
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
|
||||
{
|
||||
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
|
||||
onOpenLink('lobechat');
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
const columns = [
|
||||
@@ -177,6 +184,11 @@ const TokensTable = () => {
|
||||
node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
|
||||
onOpenLink('opencat', record.key);
|
||||
}
|
||||
},
|
||||
{
|
||||
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
|
||||
onOpenLink('lobechat');
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -382,6 +394,9 @@ const TokensTable = () => {
|
||||
case 'next-mj':
|
||||
url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||
break;
|
||||
case 'lobechat':
|
||||
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
|
||||
break;
|
||||
default:
|
||||
if (!chatLink) {
|
||||
showError('管理员未设置聊天链接');
|
||||
|
||||
@@ -30,6 +30,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
||||
@@ -63,7 +63,7 @@ const EditChannel = (props) => {
|
||||
let localModels = [];
|
||||
switch (value) {
|
||||
case 14:
|
||||
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"];
|
||||
localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"];
|
||||
break;
|
||||
case 11:
|
||||
localModels = ['PaLM-2'];
|
||||
@@ -78,7 +78,7 @@ const EditChannel = (props) => {
|
||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||
break;
|
||||
case 18:
|
||||
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0'];
|
||||
localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0'];
|
||||
break;
|
||||
case 19:
|
||||
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
"@tabler/icons-react": "^2.44.0",
|
||||
"apexcharts": "3.35.3",
|
||||
"axios": "^0.27.2",
|
||||
"date-fns": "^3.6.0",
|
||||
"dayjs": "^1.11.10",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.3.16",
|
||||
@@ -28,7 +27,6 @@
|
||||
"prop-types": "^15.8.1",
|
||||
"react": "^18.2.0",
|
||||
"react-apexcharts": "1.4.0",
|
||||
"react-datepicker": "^7.3.0",
|
||||
"react-device-detect": "^2.2.2",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-perfect-scrollbar": "^1.5.8",
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 5.4 KiB After Width: | Height: | Size: 4.3 KiB |
7
web/berry/src/assets/images/icons/oidc.svg
Normal file
7
web/berry/src/assets/images/icons/oidc.svg
Normal file
@@ -0,0 +1,7 @@
|
||||
<svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
|
||||
p-id="10969" width="200" height="200">
|
||||
<path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
|
||||
p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
|
||||
<path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
|
||||
p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
@@ -22,7 +22,12 @@ const config = {
|
||||
turnstile_site_key: '',
|
||||
version: '',
|
||||
wechat_login: false,
|
||||
wechat_qrcode: ''
|
||||
wechat_qrcode: '',
|
||||
oidc: false,
|
||||
oidc_client_id: '',
|
||||
oidc_authorization_endpoint: '',
|
||||
oidc_token_endpoint: '',
|
||||
oidc_userinfo_endpoint: '',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -179,6 +179,12 @@ export const CHANNEL_OPTIONS = {
|
||||
value: 44,
|
||||
color: 'primary'
|
||||
},
|
||||
45: {
|
||||
key: 45,
|
||||
text: 'xAI',
|
||||
value: 45,
|
||||
color: 'primary'
|
||||
},
|
||||
41: {
|
||||
key: 41,
|
||||
text: 'Novita',
|
||||
|
||||
@@ -70,6 +70,28 @@ const useLogin = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const oidcLogin = async (code, state) => {
|
||||
try {
|
||||
const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/panel');
|
||||
} else {
|
||||
dispatch({ type: LOGIN, payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/panel');
|
||||
}
|
||||
}
|
||||
return { success, message };
|
||||
} catch (err) {
|
||||
// 请求失败,设置错误信息
|
||||
return { success: false, message: '' };
|
||||
}
|
||||
}
|
||||
|
||||
const wechatLogin = async (code) => {
|
||||
try {
|
||||
const res = await API.get(`/api/oauth/wechat?code=${code}`);
|
||||
@@ -94,7 +116,7 @@ const useLogin = () => {
|
||||
navigate('/');
|
||||
};
|
||||
|
||||
return { login, logout, githubLogin, wechatLogin, larkLogin };
|
||||
return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin };
|
||||
};
|
||||
|
||||
export default useLogin;
|
||||
|
||||
@@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login'))
|
||||
const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
|
||||
const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
|
||||
const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth')));
|
||||
const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth')));
|
||||
const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
|
||||
const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
|
||||
const Home = Loadable(lazy(() => import('views/Home')));
|
||||
@@ -53,6 +54,10 @@ const OtherRoutes = {
|
||||
path: '/oauth/lark',
|
||||
element: <LarkOAuth />
|
||||
},
|
||||
{
|
||||
path: 'oauth/oidc',
|
||||
element: <OidcOAuth />
|
||||
},
|
||||
{
|
||||
path: '/404',
|
||||
element: <NotFoundView />
|
||||
|
||||
@@ -98,6 +98,21 @@ export async function onLarkOAuthClicked(lark_client_id) {
|
||||
window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
|
||||
}
|
||||
|
||||
export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/oidc`;
|
||||
const response_type = "code";
|
||||
const scope = "openid profile email";
|
||||
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
|
||||
if (openInNewTab) {
|
||||
window.open(url);
|
||||
} else
|
||||
{
|
||||
window.location.href = url;
|
||||
}
|
||||
}
|
||||
|
||||
export function isAdmin() {
|
||||
let user = localStorage.getItem('user');
|
||||
if (!user) return false;
|
||||
|
||||
94
web/berry/src/views/Authentication/Auth/OidcOAuth.js
Normal file
94
web/berry/src/views/Authentication/Auth/OidcOAuth.js
Normal file
@@ -0,0 +1,94 @@
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { showError } from 'utils/common';
|
||||
import useLogin from 'hooks/useLogin';
|
||||
|
||||
// material-ui
|
||||
import { useTheme } from '@mui/material/styles';
|
||||
import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material';
|
||||
|
||||
// project imports
|
||||
import AuthWrapper from '../AuthWrapper';
|
||||
import AuthCardWrapper from '../AuthCardWrapper';
|
||||
import Logo from 'ui-component/Logo';
|
||||
|
||||
// assets
|
||||
|
||||
// ================================|| AUTH3 - LOGIN ||================================ //
|
||||
|
||||
const OidcOAuth = () => {
|
||||
const theme = useTheme();
|
||||
const matchDownSM = useMediaQuery(theme.breakpoints.down('md'));
|
||||
|
||||
const [searchParams] = useSearchParams();
|
||||
const [prompt, setPrompt] = useState('处理中...');
|
||||
const { oidcLogin } = useLogin();
|
||||
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
const { success, message } = await oidcLogin(code, state);
|
||||
if (!success) {
|
||||
if (message) {
|
||||
showError(message);
|
||||
}
|
||||
if (count === 0) {
|
||||
setPrompt(`操作失败,重定向至登录界面中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
navigate('/login');
|
||||
return;
|
||||
}
|
||||
count++;
|
||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
await sendCode(code, state, count);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<AuthWrapper>
|
||||
<Grid container direction="column" justifyContent="flex-end">
|
||||
<Grid item xs={12}>
|
||||
<Grid container justifyContent="center" alignItems="center" sx={{ minHeight: 'calc(100vh - 136px)' }}>
|
||||
<Grid item sx={{ m: { xs: 1, sm: 3 }, mb: 0 }}>
|
||||
<AuthCardWrapper>
|
||||
<Grid container spacing={2} alignItems="center" justifyContent="center">
|
||||
<Grid item sx={{ mb: 3 }}>
|
||||
<Link to="#">
|
||||
<Logo />
|
||||
</Link>
|
||||
</Grid>
|
||||
<Grid item xs={12}>
|
||||
<Grid container direction={matchDownSM ? 'column-reverse' : 'row'} alignItems="center" justifyContent="center">
|
||||
<Grid item>
|
||||
<Stack alignItems="center" justifyContent="center" spacing={1}>
|
||||
<Typography color={theme.palette.primary.main} gutterBottom variant={matchDownSM ? 'h3' : 'h2'}>
|
||||
OIDC 登录
|
||||
</Typography>
|
||||
</Stack>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</Grid>
|
||||
<Grid item xs={12} container direction="column" justifyContent="center" alignItems="center" style={{ height: '200px' }}>
|
||||
<CircularProgress />
|
||||
<Typography variant="h3" paddingTop={'20px'}>
|
||||
{prompt}
|
||||
</Typography>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</AuthCardWrapper>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</AuthWrapper>
|
||||
);
|
||||
};
|
||||
|
||||
export default OidcOAuth;
|
||||
@@ -36,7 +36,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff';
|
||||
import Github from 'assets/images/icons/github.svg';
|
||||
import Wechat from 'assets/images/icons/wechat.svg';
|
||||
import Lark from 'assets/images/icons/lark.svg';
|
||||
import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common';
|
||||
import OIDC from 'assets/images/icons/oidc.svg';
|
||||
import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common';
|
||||
|
||||
// ============================|| FIREBASE - LOGIN ||============================ //
|
||||
|
||||
@@ -50,7 +51,7 @@ const LoginForm = ({ ...others }) => {
|
||||
// const [checked, setChecked] = useState(true);
|
||||
|
||||
let tripartiteLogin = false;
|
||||
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) {
|
||||
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) {
|
||||
tripartiteLogin = true;
|
||||
}
|
||||
|
||||
@@ -145,6 +146,29 @@ const LoginForm = ({ ...others }) => {
|
||||
</AnimateButton>
|
||||
</Grid>
|
||||
)}
|
||||
{siteInfo.oidc && (
|
||||
<Grid item xs={12}>
|
||||
<AnimateButton>
|
||||
<Button
|
||||
disableElevation
|
||||
fullWidth
|
||||
onClick={() => onOidcClicked(siteInfo.oidc_authorization_endpoint,siteInfo.oidc_client_id)}
|
||||
size="large"
|
||||
variant="outlined"
|
||||
sx={{
|
||||
color: 'grey.700',
|
||||
backgroundColor: theme.palette.grey[50],
|
||||
borderColor: theme.palette.grey[100]
|
||||
}}
|
||||
>
|
||||
<Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
|
||||
<img src={OIDC} alt="Lark" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
|
||||
</Box>
|
||||
使用 OIDC 登录
|
||||
</Button>
|
||||
</AnimateButton>
|
||||
</Grid>
|
||||
)}
|
||||
<Grid item xs={12}>
|
||||
<Box
|
||||
sx={{
|
||||
|
||||
@@ -268,6 +268,8 @@ function renderBalance(type, balance) {
|
||||
return <span>¥{balance.toFixed(2)}</span>;
|
||||
case 13: // AIGC2D
|
||||
return <span>{renderNumber(balance)}</span>;
|
||||
case 44: // SiliconFlow
|
||||
return <span>¥{balance.toFixed(2)}</span>;
|
||||
default:
|
||||
return <span>不支持</span>;
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ const typeConfig = {
|
||||
other: '版本号'
|
||||
},
|
||||
input: {
|
||||
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']
|
||||
models: ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0']
|
||||
},
|
||||
prompt: {
|
||||
key: '按照如下格式输入:APPID|APISecret|APIKey',
|
||||
@@ -223,6 +223,9 @@ const typeConfig = {
|
||||
},
|
||||
modelGroup: 'anthropic'
|
||||
},
|
||||
45: {
|
||||
modelGroup: 'xai'
|
||||
},
|
||||
};
|
||||
|
||||
export { defaultConfig, typeConfig };
|
||||
|
||||
@@ -20,7 +20,7 @@ import SubCard from 'ui-component/cards/SubCard';
|
||||
import { IconBrandWechat, IconBrandGithub, IconMail } from '@tabler/icons-react';
|
||||
import Label from 'ui-component/Label';
|
||||
import { API } from 'utils/api';
|
||||
import { showError, showSuccess } from 'utils/common';
|
||||
import { onOidcClicked, showError, showSuccess } from 'utils/common';
|
||||
import { onGitHubOAuthClicked, onLarkOAuthClicked, copy } from 'utils/common';
|
||||
import * as Yup from 'yup';
|
||||
import WechatModal from 'views/Authentication/AuthForms/WechatModal';
|
||||
@@ -28,6 +28,7 @@ import { useSelector } from 'react-redux';
|
||||
import EmailModal from './component/EmailModal';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { ReactComponent as Lark } from 'assets/images/icons/lark.svg';
|
||||
import { ReactComponent as OIDC } from 'assets/images/icons/oidc.svg';
|
||||
|
||||
const validationSchema = Yup.object().shape({
|
||||
username: Yup.string().required('用户名 不能为空').min(3, '用户名 不能小于 3 个字符'),
|
||||
@@ -123,6 +124,15 @@ export default function Profile() {
|
||||
loadUser().then();
|
||||
}, [status]);
|
||||
|
||||
function getOidcId(){
|
||||
if (!inputs.oidc_id) return '';
|
||||
let oidc_id = inputs.oidc_id;
|
||||
if (inputs.oidc_id.length > 8) {
|
||||
oidc_id = inputs.oidc_id.slice(0, 6) + '...' + inputs.oidc_id.slice(-6);
|
||||
}
|
||||
return oidc_id;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<UserCard>
|
||||
@@ -141,6 +151,9 @@ export default function Profile() {
|
||||
<Label variant="ghost" color={inputs.lark_id ? 'primary' : 'default'}>
|
||||
<SvgIcon component={Lark} inheritViewBox="0 0 24 24" /> {inputs.lark_id || '未绑定'}
|
||||
</Label>
|
||||
<Label variant="ghost" color={inputs.oidc_id ? 'primary' : 'default'}>
|
||||
<SvgIcon component={OIDC} inheritViewBox="0 0 24 24" /> {getOidcId() || '未绑定'}
|
||||
</Label>
|
||||
</Stack>
|
||||
<SubCard title="个人信息">
|
||||
<Grid container spacing={2}>
|
||||
@@ -216,6 +229,13 @@ export default function Profile() {
|
||||
</Button>
|
||||
</Grid>
|
||||
)}
|
||||
{status.oidc && !inputs.oidc_id && (
|
||||
<Grid xs={12} md={4}>
|
||||
<Button variant="contained" onClick={() => onOidcClicked(status.oidc_authorization_endpoint,status.oidc_client_id,true)}>
|
||||
绑定 OIDC 账号
|
||||
</Button>
|
||||
</Grid>
|
||||
)}
|
||||
<Grid xs={12} md={4}>
|
||||
<Button
|
||||
variant="contained"
|
||||
|
||||
@@ -33,6 +33,13 @@ const SystemSetting = () => {
|
||||
GitHubClientSecret: '',
|
||||
LarkClientId: '',
|
||||
LarkClientSecret: '',
|
||||
OidcEnabled: '',
|
||||
OidcWellKnown: '',
|
||||
OidcClientId: '',
|
||||
OidcClientSecret: '',
|
||||
OidcAuthorizationEndpoint: '',
|
||||
OidcTokenEndpoint: '',
|
||||
OidcUserinfoEndpoint: '',
|
||||
Notice: '',
|
||||
SMTPServer: '',
|
||||
SMTPPort: '',
|
||||
@@ -94,6 +101,7 @@ const SystemSetting = () => {
|
||||
case 'TurnstileCheckEnabled':
|
||||
case 'EmailDomainRestrictionEnabled':
|
||||
case 'RegisterEnabled':
|
||||
case 'OidcEnabled':
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
break;
|
||||
default:
|
||||
@@ -142,8 +150,15 @@ const SystemSetting = () => {
|
||||
name === 'MessagePusherAddress' ||
|
||||
name === 'MessagePusherToken' ||
|
||||
name === 'LarkClientId' ||
|
||||
name === 'LarkClientSecret'
|
||||
) {
|
||||
name === 'LarkClientSecret' ||
|
||||
name === 'OidcClientId' ||
|
||||
name === 'OidcClientSecret' ||
|
||||
name === 'OidcWellKnown' ||
|
||||
name === 'OidcAuthorizationEndpoint' ||
|
||||
name === 'OidcTokenEndpoint' ||
|
||||
name === 'OidcUserinfoEndpoint'
|
||||
)
|
||||
{
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
} else {
|
||||
await updateOption(name, value);
|
||||
@@ -225,6 +240,43 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitOidc = async () => {
|
||||
if (inputs.OidcWellKnown !== '') {
|
||||
if (!inputs.OidcWellKnown.startsWith('http://') && !inputs.OidcWellKnown.startsWith('https://')) {
|
||||
showError('Well-Known URL 必须以 http:// 或 https:// 开头');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const res = await API.get(inputs.OidcWellKnown);
|
||||
inputs.OidcAuthorizationEndpoint = res.data['authorization_endpoint'];
|
||||
inputs.OidcTokenEndpoint = res.data['token_endpoint'];
|
||||
inputs.OidcUserinfoEndpoint = res.data['userinfo_endpoint'];
|
||||
showSuccess('获取 OIDC 配置成功!');
|
||||
} catch (err) {
|
||||
showError("获取 OIDC 配置失败,请检查网络状况和 Well-Known URL 是否正确");
|
||||
}
|
||||
}
|
||||
|
||||
if (originInputs['OidcWellKnown'] !== inputs.OidcWellKnown) {
|
||||
await updateOption('OidcWellKnown', inputs.OidcWellKnown);
|
||||
}
|
||||
if (originInputs['OidcClientId'] !== inputs.OidcClientId) {
|
||||
await updateOption('OidcClientId', inputs.OidcClientId);
|
||||
}
|
||||
if (originInputs['OidcClientSecret'] !== inputs.OidcClientSecret && inputs.OidcClientSecret !== '') {
|
||||
await updateOption('OidcClientSecret', inputs.OidcClientSecret);
|
||||
}
|
||||
if (originInputs['OidcAuthorizationEndpoint'] !== inputs.OidcAuthorizationEndpoint) {
|
||||
await updateOption('OidcAuthorizationEndpoint', inputs.OidcAuthorizationEndpoint);
|
||||
}
|
||||
if (originInputs['OidcTokenEndpoint'] !== inputs.OidcTokenEndpoint) {
|
||||
await updateOption('OidcTokenEndpoint', inputs.OidcTokenEndpoint);
|
||||
}
|
||||
if (originInputs['OidcUserinfoEndpoint'] !== inputs.OidcUserinfoEndpoint) {
|
||||
await updateOption('OidcUserinfoEndpoint', inputs.OidcUserinfoEndpoint);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Stack spacing={2}>
|
||||
@@ -291,6 +343,12 @@ const SystemSetting = () => {
|
||||
control={<Checkbox checked={inputs.GitHubOAuthEnabled === 'true'} onChange={handleInputChange} name="GitHubOAuthEnabled" />}
|
||||
/>
|
||||
</Grid>
|
||||
<Grid xs={12} md={3}>
|
||||
<FormControlLabel
|
||||
label="允许通过 OIDC 登录 & 注册"
|
||||
control={<Checkbox checked={inputs.OidcEnabled === 'true'} onChange={handleInputChange} name="OidcEnabled" />}
|
||||
/>
|
||||
</Grid>
|
||||
<Grid xs={12} md={3}>
|
||||
<FormControlLabel
|
||||
label="允许通过微信登录 & 注册"
|
||||
@@ -616,6 +674,117 @@ const SystemSetting = () => {
|
||||
</Grid>
|
||||
</Grid>
|
||||
</SubCard>
|
||||
|
||||
<SubCard
|
||||
title="配置 OIDC"
|
||||
subTitle={
|
||||
<span>
|
||||
用以支持通过 OIDC 登录,例如 Okta、Auth0 等兼容 OIDC 协议的 IdP
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<Grid container spacing={ { xs: 3, sm: 2, md: 4 } }>
|
||||
<Grid xs={ 12 } md={ 12 }>
|
||||
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
|
||||
主页链接填 <code>{ inputs.ServerAddress }</code>
|
||||
,重定向 URL 填 <code>{ `${ inputs.ServerAddress }/oauth/oidc` }</code>
|
||||
</Alert> <br />
|
||||
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
|
||||
若你的 OIDC Provider 支持 Discovery Endpoint,你可以仅填写 OIDC Well-Known URL,系统会自动获取 OIDC 配置
|
||||
</Alert>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcClientId">Client ID</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcClientId"
|
||||
name="OidcClientId"
|
||||
value={ inputs.OidcClientId || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Client ID"
|
||||
placeholder="输入 OIDC 的 Client ID"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcClientSecret">Client Secret</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcClientSecret"
|
||||
name="OidcClientSecret"
|
||||
value={ inputs.OidcClientSecret || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Client Secret"
|
||||
placeholder="敏感信息不会发送到前端显示"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcWellKnown">Well-Known URL</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcWellKnown"
|
||||
name="OidcWellKnown"
|
||||
value={ inputs.OidcWellKnown || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Well-Known URL"
|
||||
placeholder="请输入 OIDC 的 Well-Known URL"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcAuthorizationEndpoint">Authorization Endpoint</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcAuthorizationEndpoint"
|
||||
name="OidcAuthorizationEndpoint"
|
||||
value={ inputs.OidcAuthorizationEndpoint || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Authorization Endpoint"
|
||||
placeholder="输入 OIDC 的 Authorization Endpoint"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcTokenEndpoint">Token Endpoint</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcTokenEndpoint"
|
||||
name="OidcTokenEndpoint"
|
||||
value={ inputs.OidcTokenEndpoint || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Token Endpoint"
|
||||
placeholder="输入 OIDC 的 Token Endpoint"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 } md={ 6 }>
|
||||
<FormControl fullWidth>
|
||||
<InputLabel htmlFor="OidcUserinfoEndpoint">Userinfo Endpoint</InputLabel>
|
||||
<OutlinedInput
|
||||
id="OidcUserinfoEndpoint"
|
||||
name="OidcUserinfoEndpoint"
|
||||
value={ inputs.OidcUserinfoEndpoint || '' }
|
||||
onChange={ handleInputChange }
|
||||
label="Userinfo Endpoint"
|
||||
placeholder="输入 OIDC 的 Userinfo Endpoint"
|
||||
disabled={ loading }
|
||||
/>
|
||||
</FormControl>
|
||||
</Grid>
|
||||
<Grid xs={ 12 }>
|
||||
<Button variant="contained" onClick={ submitOidc }>
|
||||
保存 OIDC 设置
|
||||
</Button>
|
||||
</Grid>
|
||||
</Grid>
|
||||
</SubCard>
|
||||
|
||||
<SubCard
|
||||
title="配置 Message Pusher"
|
||||
subTitle={
|
||||
|
||||
@@ -32,7 +32,8 @@ const COPY_OPTIONS = [
|
||||
encode: false
|
||||
},
|
||||
{ key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
|
||||
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
|
||||
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true },
|
||||
{ key: 'lobechat', text: 'LobeChat', url: 'https://lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"sk-{key}","baseURL":"{serverAddress}"}}}', encode: true }
|
||||
];
|
||||
|
||||
function replacePlaceholders(text, key, serverAddress) {
|
||||
|
||||
@@ -2,9 +2,7 @@ import PropTypes from 'prop-types';
|
||||
import * as Yup from 'yup';
|
||||
import { Formik } from 'formik';
|
||||
import { useTheme } from '@mui/material/styles';
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { format } from 'date-fns';
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
import {
|
||||
Dialog,
|
||||
DialogTitle,
|
||||
@@ -19,11 +17,7 @@ import {
|
||||
Select,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
FormHelperText,
|
||||
TextField,
|
||||
Typography,
|
||||
Switch,
|
||||
FormControlLabel
|
||||
FormHelperText
|
||||
} from '@mui/material';
|
||||
|
||||
import Visibility from '@mui/icons-material/Visibility';
|
||||
@@ -50,17 +44,6 @@ const validationSchema = Yup.object().shape({
|
||||
is: false,
|
||||
then: Yup.number().min(0, '额度 不能小于 0'),
|
||||
otherwise: Yup.number()
|
||||
}),
|
||||
expiration_date: Yup.mixed().when('group', {
|
||||
is: (group) => group !== 'default',
|
||||
then: Yup.mixed().test(
|
||||
'expiration_date-required',
|
||||
'到期时间 不能为空',
|
||||
function (value) {
|
||||
const { expiration_date } = this.parent;
|
||||
return expiration_date === -1 || !!expiration_date;
|
||||
}
|
||||
),
|
||||
})
|
||||
});
|
||||
|
||||
@@ -70,8 +53,7 @@ const originInputs = {
|
||||
display_name: '',
|
||||
password: '',
|
||||
group: 'default',
|
||||
quota: 0,
|
||||
expiration_date: null
|
||||
quota: 0
|
||||
};
|
||||
|
||||
const EditModal = ({ open, userId, onCancel, onOk }) => {
|
||||
@@ -83,12 +65,6 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
|
||||
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
|
||||
setSubmitting(true);
|
||||
|
||||
// 将到期时间转换为 Unix 时间戳
|
||||
if (values.expiration_date && values.expiration_date !== -1) {
|
||||
const date = new Date(values.expiration_date);
|
||||
values.expiration_date = Math.floor(date.getTime() / 1000); // 转换为秒级的 Unix 时间戳
|
||||
}
|
||||
|
||||
let res;
|
||||
if (values.is_edit) {
|
||||
res = await API.put(`/api/user/`, { ...values, id: parseInt(userId) });
|
||||
@@ -119,23 +95,16 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
|
||||
event.preventDefault();
|
||||
};
|
||||
|
||||
const loadUser = useCallback(async () => {
|
||||
const loadUser = async () => {
|
||||
let res = await API.get(`/api/user/${userId}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
data.is_edit = true;
|
||||
|
||||
// 将 Unix 时间戳转换为日期字符串
|
||||
if (data.expiration_date && data.expiration_date !== -1) {
|
||||
const date = new Date(data.expiration_date * 1000); // 转换为毫秒级的时间戳
|
||||
data.expiration_date = format(date, 'yyyy-MM-dd'); // 格式化为 date 格式
|
||||
}
|
||||
|
||||
setInputs(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
}, [userId]);
|
||||
};
|
||||
|
||||
const fetchGroups = async () => {
|
||||
try {
|
||||
@@ -153,203 +122,159 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
|
||||
} else {
|
||||
setInputs(originInputs);
|
||||
}
|
||||
}, [userId, loadUser]);
|
||||
}, [userId]);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
|
||||
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
|
||||
{userId ? '编辑用户' : '新建用户'}
|
||||
</DialogTitle>
|
||||
<Divider />
|
||||
<DialogContent>
|
||||
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
|
||||
{({ errors, handleBlur, handleChange, handleSubmit, setFieldValue, touched, values, isSubmitting }) => (
|
||||
<form noValidate onSubmit={handleSubmit}>
|
||||
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
|
||||
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
|
||||
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
|
||||
{userId ? '编辑用户' : '新建用户'}
|
||||
</DialogTitle>
|
||||
<Divider />
|
||||
<DialogContent>
|
||||
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
|
||||
{({ errors, handleBlur, handleChange, handleSubmit, touched, values, isSubmitting }) => (
|
||||
<form noValidate onSubmit={handleSubmit}>
|
||||
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-username-label"
|
||||
label="用户名"
|
||||
type="text"
|
||||
value={values.username}
|
||||
name="username"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'username' }}
|
||||
aria-describedby="helper-text-channel-username-label"
|
||||
/>
|
||||
{touched.username && errors.username && (
|
||||
<FormHelperText error id="helper-tex-channel-username-label">
|
||||
{errors.username}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-display_name-label"
|
||||
label="显示名称"
|
||||
type="text"
|
||||
value={values.display_name}
|
||||
name="display_name"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'display_name' }}
|
||||
aria-describedby="helper-text-channel-display_name-label"
|
||||
/>
|
||||
{touched.display_name && errors.display_name && (
|
||||
<FormHelperText error id="helper-tex-channel-display_name-label">
|
||||
{errors.display_name}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-password-label"
|
||||
label="密码"
|
||||
type={showPassword ? 'text' : 'password'}
|
||||
value={values.password}
|
||||
name="password"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'password' }}
|
||||
endAdornment={
|
||||
<InputAdornment position="end">
|
||||
<IconButton
|
||||
aria-label="toggle password visibility"
|
||||
onClick={handleClickShowPassword}
|
||||
onMouseDown={handleMouseDownPassword}
|
||||
edge="end"
|
||||
size="large"
|
||||
>
|
||||
{showPassword ? <Visibility /> : <VisibilityOff />}
|
||||
</IconButton>
|
||||
</InputAdornment>
|
||||
}
|
||||
aria-describedby="helper-text-channel-password-label"
|
||||
/>
|
||||
{touched.password && errors.password && (
|
||||
<FormHelperText error id="helper-tex-channel-password-label">
|
||||
{errors.password}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
{values.is_edit && (
|
||||
<>
|
||||
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-username-label"
|
||||
label="用户名"
|
||||
type="text"
|
||||
value={values.username}
|
||||
name="username"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'username' }}
|
||||
aria-describedby="helper-text-channel-username-label"
|
||||
id="channel-quota-label"
|
||||
label="额度"
|
||||
type="number"
|
||||
value={values.quota}
|
||||
name="quota"
|
||||
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
aria-describedby="helper-text-channel-quota-label"
|
||||
disabled={values.unlimited_quota}
|
||||
/>
|
||||
{touched.username && errors.username && (
|
||||
<FormHelperText error id="helper-tex-channel-username-label">
|
||||
{errors.username}
|
||||
</FormHelperText>
|
||||
|
||||
{touched.quota && errors.quota && (
|
||||
<FormHelperText error id="helper-tex-channel-quota-label">
|
||||
{errors.quota}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-display_name-label"
|
||||
label="显示名称"
|
||||
type="text"
|
||||
value={values.display_name}
|
||||
name="display_name"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'display_name' }}
|
||||
aria-describedby="helper-text-channel-display_name-label"
|
||||
/>
|
||||
{touched.display_name && errors.display_name && (
|
||||
<FormHelperText error id="helper-tex-channel-display_name-label">
|
||||
{errors.display_name}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-password-label"
|
||||
label="密码"
|
||||
type={showPassword ? 'text' : 'password'}
|
||||
value={values.password}
|
||||
name="password"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
inputProps={{ autoComplete: 'password' }}
|
||||
endAdornment={
|
||||
<InputAdornment position="end">
|
||||
<IconButton
|
||||
aria-label="toggle password visibility"
|
||||
onClick={handleClickShowPassword}
|
||||
onMouseDown={handleMouseDownPassword}
|
||||
edge="end"
|
||||
size="large"
|
||||
>
|
||||
{showPassword ? <Visibility /> : <VisibilityOff />}
|
||||
</IconButton>
|
||||
</InputAdornment>
|
||||
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
|
||||
<Select
|
||||
id="channel-group-label"
|
||||
label="分组"
|
||||
value={values.group}
|
||||
name="group"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
MenuProps={{
|
||||
PaperProps: {
|
||||
style: {
|
||||
maxHeight: 200
|
||||
}
|
||||
}
|
||||
aria-describedby="helper-text-channel-password-label"
|
||||
/>
|
||||
{touched.password && errors.password && (
|
||||
<FormHelperText error id="helper-tex-channel-password-label">
|
||||
{errors.password}
|
||||
</FormHelperText>
|
||||
}}
|
||||
>
|
||||
{groupOptions.map((option) => {
|
||||
return (
|
||||
<MenuItem key={option} value={option}>
|
||||
{option}
|
||||
</MenuItem>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
{touched.group && errors.group && (
|
||||
<FormHelperText error id="helper-tex-channel-group-label">
|
||||
{errors.group}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
{values.is_edit && (
|
||||
<>
|
||||
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
|
||||
<OutlinedInput
|
||||
id="channel-quota-label"
|
||||
label="额度"
|
||||
type="number"
|
||||
value={values.quota}
|
||||
name="quota"
|
||||
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
aria-describedby="helper-text-channel-quota-label"
|
||||
disabled={values.unlimited_quota}
|
||||
/>
|
||||
|
||||
{touched.quota && errors.quota && (
|
||||
<FormHelperText error id="helper-tex-channel-quota-label">
|
||||
{errors.quota}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
|
||||
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
|
||||
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
|
||||
<Select
|
||||
id="channel-group-label"
|
||||
label="分组"
|
||||
value={values.group}
|
||||
name="group"
|
||||
onBlur={handleBlur}
|
||||
onChange={(e) => {
|
||||
handleChange(e);
|
||||
if (e.target.value === 'default') {
|
||||
setFieldValue('expiration_date', null);
|
||||
}
|
||||
}}
|
||||
MenuProps={{
|
||||
PaperProps: {
|
||||
style: {
|
||||
maxHeight: 200
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
{groupOptions.map((option) => {
|
||||
return (
|
||||
<MenuItem key={option} value={option}>
|
||||
{option}
|
||||
</MenuItem>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
{touched.group && errors.group && (
|
||||
<FormHelperText error id="helper-tex-channel-group-label">
|
||||
{errors.group}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
</>
|
||||
)}
|
||||
|
||||
{values.group !== 'default' && (
|
||||
<FormControl fullWidth error={Boolean(touched.expiration_date && errors.expiration_date)} sx={{ ...theme.typography.otherInput }}>
|
||||
<TextField
|
||||
id="channel-expiration_date-label"
|
||||
label="到期时间"
|
||||
type="date" // 修改为 date
|
||||
value={values.expiration_date}
|
||||
name="expiration_date"
|
||||
onBlur={handleBlur}
|
||||
onChange={handleChange}
|
||||
InputLabelProps={{
|
||||
shrink: true
|
||||
}}
|
||||
inputProps={{
|
||||
max: '9999-12-31' // 设置最大日期
|
||||
}}
|
||||
aria-describedby="helper-text-channel-expiration_date-label"
|
||||
disabled={values.expiration_date === -1}
|
||||
/>
|
||||
{touched.expiration_date && errors.expiration_date && (
|
||||
<FormHelperText error id="helper-tex-channel-expiration_date-label">
|
||||
{errors.expiration_date}
|
||||
</FormHelperText>
|
||||
)}
|
||||
<FormControlLabel
|
||||
control={
|
||||
<Switch
|
||||
checked={values.expiration_date === -1}
|
||||
onChange={(e) => setFieldValue('expiration_date', e.target.checked ? -1 : '')}
|
||||
name="permanent"
|
||||
color="primary"
|
||||
/>
|
||||
}
|
||||
label="永不过期"
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<DialogActions>
|
||||
<Button onClick={onCancel}>取消</Button>
|
||||
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
|
||||
提交
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</>
|
||||
)}
|
||||
<DialogActions>
|
||||
<Button onClick={onCancel}>取消</Button>
|
||||
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
|
||||
提交
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</form>
|
||||
)}
|
||||
</Formik>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -360,4 +285,4 @@ EditModal.propTypes = {
|
||||
userId: PropTypes.number,
|
||||
onCancel: PropTypes.func,
|
||||
onOk: PropTypes.func
|
||||
};
|
||||
};
|
||||
|
||||
@@ -52,11 +52,19 @@ function renderBalance(type, balance) {
|
||||
return <span>¥{balance.toFixed(2)}</span>;
|
||||
case 13: // AIGC2D
|
||||
return <span>{renderNumber(balance)}</span>;
|
||||
case 44: // SiliconFlow
|
||||
return <span>¥{balance.toFixed(2)}</span>;
|
||||
default:
|
||||
return <span>不支持</span>;
|
||||
}
|
||||
}
|
||||
|
||||
function isShowDetail() {
|
||||
return localStorage.getItem("show_detail") === "true";
|
||||
}
|
||||
|
||||
const promptID = "detail"
|
||||
|
||||
const ChannelsTable = () => {
|
||||
const [channels, setChannels] = useState([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
@@ -64,7 +72,8 @@ const ChannelsTable = () => {
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [updatingBalance, setUpdatingBalance] = useState(false);
|
||||
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
|
||||
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt(promptID));
|
||||
const [showDetail, setShowDetail] = useState(isShowDetail());
|
||||
|
||||
const loadChannels = async (startIdx) => {
|
||||
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
||||
@@ -118,6 +127,11 @@ const ChannelsTable = () => {
|
||||
await loadChannels(activePage - 1);
|
||||
};
|
||||
|
||||
const toggleShowDetail = () => {
|
||||
setShowDetail(!showDetail);
|
||||
localStorage.setItem("show_detail", (!showDetail).toString());
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
loadChannels(0)
|
||||
.then()
|
||||
@@ -362,11 +376,13 @@ const ChannelsTable = () => {
|
||||
showPrompt && (
|
||||
<Message onDismiss={() => {
|
||||
setShowPrompt(false);
|
||||
setPromptShown("channel-test");
|
||||
setPromptShown(promptID);
|
||||
}}>
|
||||
OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
|
||||
<br/>
|
||||
渠道测试仅支持 chat 模型,优先使用 gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。
|
||||
<br/>
|
||||
点击下方详情按钮可以显示余额以及设置额外的测试模型。
|
||||
</Message>
|
||||
)
|
||||
}
|
||||
@@ -426,6 +442,7 @@ const ChannelsTable = () => {
|
||||
onClick={() => {
|
||||
sortChannel('balance');
|
||||
}}
|
||||
hidden={!showDetail}
|
||||
>
|
||||
余额
|
||||
</Table.HeaderCell>
|
||||
@@ -437,7 +454,7 @@ const ChannelsTable = () => {
|
||||
>
|
||||
优先级
|
||||
</Table.HeaderCell>
|
||||
<Table.HeaderCell>测试模型</Table.HeaderCell>
|
||||
<Table.HeaderCell hidden={!showDetail}>测试模型</Table.HeaderCell>
|
||||
<Table.HeaderCell>操作</Table.HeaderCell>
|
||||
</Table.Row>
|
||||
</Table.Header>
|
||||
@@ -465,7 +482,7 @@ const ChannelsTable = () => {
|
||||
basic
|
||||
/>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<Table.Cell hidden={!showDetail}>
|
||||
<Popup
|
||||
trigger={<span onClick={() => {
|
||||
updateChannelBalance(channel.id, channel.name, idx);
|
||||
@@ -492,7 +509,7 @@ const ChannelsTable = () => {
|
||||
basic
|
||||
/>
|
||||
</Table.Cell>
|
||||
<Table.Cell>
|
||||
<Table.Cell hidden={!showDetail}>
|
||||
<Dropdown
|
||||
placeholder='请选择测试模型'
|
||||
selection
|
||||
@@ -571,7 +588,7 @@ const ChannelsTable = () => {
|
||||
|
||||
<Table.Footer>
|
||||
<Table.Row>
|
||||
<Table.HeaderCell colSpan='9'>
|
||||
<Table.HeaderCell colSpan={showDetail ? "10" : "8"}>
|
||||
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
||||
添加新的渠道
|
||||
</Button>
|
||||
@@ -609,6 +626,7 @@ const ChannelsTable = () => {
|
||||
}
|
||||
/>
|
||||
<Button size='small' onClick={refresh} loading={loading}>刷新</Button>
|
||||
<Button size='small' onClick={toggleShowDetail}>{showDetail ? "隐藏详情" : "详情"}</Button>
|
||||
</Table.HeaderCell>
|
||||
</Table.Row>
|
||||
</Table.Footer>
|
||||
|
||||
@@ -10,12 +10,14 @@ const COPY_OPTIONS = [
|
||||
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
|
||||
{ key: 'ama', text: 'BotGem', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
|
||||
];
|
||||
|
||||
const OPEN_LINK_OPTIONS = [
|
||||
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
|
||||
{ key: 'ama', text: 'BotGem', value: 'ama' },
|
||||
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
|
||||
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
|
||||
];
|
||||
|
||||
function renderTimestamp(timestamp) {
|
||||
@@ -114,6 +116,9 @@ const TokensTable = () => {
|
||||
case 'next':
|
||||
url = nextUrl;
|
||||
break;
|
||||
case 'lobechat':
|
||||
url = nextLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
|
||||
break;
|
||||
default:
|
||||
url = `sk-${key}`;
|
||||
}
|
||||
@@ -153,7 +158,11 @@ const TokensTable = () => {
|
||||
case 'opencat':
|
||||
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
|
||||
break;
|
||||
|
||||
|
||||
case 'lobechat':
|
||||
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`;
|
||||
break;
|
||||
|
||||
default:
|
||||
url = defaultUrl;
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ export const CHANNEL_OPTIONS = [
|
||||
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
|
||||
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
|
||||
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
|
||||
{ key: 45, text: 'xAI', value: 45, color: 'blue' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||
|
||||
@@ -2,7 +2,7 @@ import React from 'react';
|
||||
import { Header, Segment } from 'semantic-ui-react';
|
||||
import ChannelsTable from '../../components/ChannelsTable';
|
||||
|
||||
const File = () => (
|
||||
const Channel = () => (
|
||||
<>
|
||||
<Segment>
|
||||
<Header as='h3'>管理渠道</Header>
|
||||
@@ -11,4 +11,4 @@ const File = () => (
|
||||
</>
|
||||
);
|
||||
|
||||
export default File;
|
||||
export default Channel;
|
||||
|
||||
Reference in New Issue
Block a user