Compare commits

..

26 Commits

Author SHA1 Message Date
jinjianming
67cbae4d8e Merge ade00eed1c into 99c8c77504 2024-09-21 23:25:37 +08:00
OnEvent
99c8c77504 feat: add oidc support (#1725)
Some checks are pending
CI / Unit tests (push) Waiting to run
CI / commit_lint (push) Waiting to run
* feat: add the ui for configuring the third-party standard OAuth2.0/OIDC.

- update SystemSetting.js
- add setup ui
- add configuration

* feat: add the ui for "allow the OAuth 2.0 to login"

- update SystemSetting.js

* feat: add OAuth 2.0 web ui and its process functions

- update common.js
- update AuthLogin.js
- update config.js

* fix: missing "Userinfo" endpoint configuration entry, used by OAuth clients to request user information from the IdP.

- update config.js
- update SystemSetting.js

* feat: updated the icons for Lark and OIDC to match the style of the icons for WeChat, EMail, GitHub.

- update lark.svg
- new oidc.svg

* refactor: Changing OAuth 2.0 to OIDC

* feat: add OIDC login method

* feat: Add support for OIDC login to the backend

* fix: Change the AppId and AppSecret on the Web UI to the standard usage: ClientId, ClientSecret.

* feat: Support quick configuration of OIDC through Well-Known Discovery Endpoint

* feat: Standardize terminology, add well-known configuration

- Change the AppId and AppSecret on the Server End to the standard usage: ClientId, ClientSecret.
- add Well-Known configuration to store in database, no actual use in server end but store and display in web ui only
2024-09-21 23:03:20 +08:00
TAKO
649ecbf29c feat: support new openai models (4o 0806, chatgpt-4o-latest) (#1721)
* feat: support new model gpt-4o-2024-08-06

* feat: support new model chatgpt-4o-latest
2024-09-21 23:01:19 +08:00
qinguoyi
3a27c90910 fix: getTokenById return token nil, make panic (#1728)
* fix:getTokenById return token nil, make panic

* chore: remove useless err check

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 23:00:29 +08:00
千寻简
cba82404ae feat: add lobechat open link options (#1741)
Co-authored-by: Star <iii9777@163.com>
2024-09-21 22:49:31 +08:00
forrestlinfeng
c9ac670ba1 feat: update stepfun models (#1740)
Co-authored-by: chenlinfeng <chenlinfeng@step.ai>
2024-09-21 22:48:46 +08:00
leavegee
15f815c23c fix: fix ali embedding model always use v1 (#1747)
* fix:ali embedding model: v2 and v3

* chore: use ctxkey.RequestModel to eliminate hardcoding

---------

Co-authored-by: xuejia <gexuejia@djbx.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:40:06 +08:00
majian
89b63ca96f feat: ResponseFormat support json_schema (#1759)
* feat: responseFormat support json_schema

* chore: rename struct name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-09-21 22:35:24 +08:00
Ghostz
8cc54489b9 feat: update disabled channel (#1780)
* Update disabled channel

* Update manage.go

* Update manage.go

* chore: add missing space

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-09-21 22:31:53 +08:00
guogeer
58bf60805e fix: postgres use COALESCE replace null (#1793)
Co-authored-by: jinqi.guo <jinqi.guo@ubtrobot.com>
2024-09-21 22:13:31 +08:00
AJ's Life Journey
6714cf96d6 fix: Groq organization not auto-disabled when blocked (#1822) 2024-09-21 22:12:09 +08:00
longkeyy
f9774698e9 feat: synchronize with the official release of the groq model (#1677)
update groq add gemma2-9b-it llama3.1 family fixup price k/token -> m/token
2024-08-06 23:51:08 +08:00
TAKO
2af6f6a166 feat: add Cloudflare New Free Model Llama 3.1 8b (#1703) 2024-08-06 23:49:48 +08:00
MotorBottle
04bb3ef392 feat: add Max Tokens and Context Window Setting Options for Ollama Channel (#1694)
* Update main.go with max_tokens param

* Update model.go with max_tokens param

* Update model.go

* Update main.go

* Update main.go

* Adds num_ctx param for Ollama Channel

* Added num_ctx param for ollama adapter

* Added num_ctx param for ollama adapter

* Improved data process logic
2024-08-06 23:44:37 +08:00
longkeyy
b4bfa418a8 feat: update gemini model and price (#1705) 2024-08-06 23:43:33 +08:00
SLKun
e7e99e558a feat: update Ollama embedding API to latest version with multi-text embedding support (#1715) 2024-08-06 23:43:20 +08:00
Shenghang Tsai
402fcf7f79 feat: add SiliconFlow (#1717)
* Add SiliconFlow

* Update README.md

* Update README.md

* Update channel.constants.js

* Update ChannelConstants.js

* Update channel.constants.js

* Update ChannelConstants.js

* Update compatible.go

* Update README.md
2024-08-06 23:42:25 +08:00
Junyan Qin
36039e329e docs: update introduction for QChatGPT (#1707) 2024-08-06 23:33:43 +08:00
Laisky.Cai
c936198ac8 feat: add Proxy channel type and relay mode (#1678)
Add the Proxy channel type and relay mode to support proxying requests to custom upstream services.
2024-07-22 22:51:19 +08:00
TAKO
296ab013b8 feat: support gpt-4o mini (#1665)
* feat: support gpt-4o mini

* feat: fix gpt-4o mini image price
2024-07-22 22:44:08 +08:00
zijiren
5f03c856b4 feat: fast build linux/arm64 frontend (#1663)
* feat: fast build linux/arm64 frontend

* fix: dockerfile as replace to AS

* fix: trim space
2024-07-22 22:39:22 +08:00
igophper
39383e5532 fix: support embedding models for doubao (#1662)
Fixes #1594
2024-07-22 22:38:50 +08:00
jinjianmingming
ade00eed1c fix: 修复Update未能预期执行问题 2024-07-19 10:00:58 +08:00
jinjianmingming
66d31102d3 feat: 调整为定时每晚00:00
feat: 添加Redis更新缓存操作
fate: expiration_date设置默认值0
2024-07-18 22:58:04 +08:00
jinjianmingming
1aba1e4ac3 feat: 优化用户等级刷新测试,默认24h扫描一次,可以通过TIMER_FREQUENCY设置时间单位是小时 2024-07-18 10:39:21 +08:00
jinjianmingming
b86a3acd37 feat: 引入到期时间概念,当非default的分组时将有到期时间概念
fix: 修复旧用户的vip被降级,尽量不对原先用户做变动
2024-07-17 16:18:55 +08:00
67 changed files with 1587 additions and 462 deletions

View File

@@ -1,61 +0,0 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,4 +1,4 @@
name: Publish Docker image (amd64, English)
name: Publish Docker image (English)
on:
push:
@@ -34,6 +34,13 @@ jobs:
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
@@ -51,6 +58,7 @@ jobs:
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,10 +1,9 @@
name: Publish Docker image (arm64)
name: Publish Docker image
on:
push:
tags:
- 'v*.*.*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:

View File

@@ -1,4 +1,4 @@
FROM node:16 as builder
FROM --platform=$BUILDPLATFORM node:16 AS builder
WORKDIR /web
COPY ./VERSION .

View File

@@ -89,6 +89,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/)
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -251,9 +252,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
#### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
运行期间可以通过`!model`命令查看、切换可用模型。
### 部署到第三方平台
<details>

View File

@@ -35,6 +35,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
@@ -70,6 +71,13 @@ var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -99,6 +107,7 @@ var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var TimerFrequency = env.Int("TIMER_FREQUENCY", 24) // unit is hour
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)

225
controller/auth/oidc.go Normal file
View File

@@ -0,0 +1,225 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
if err == nil {
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
}
} else {
userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId)

View File

@@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
"oidc": config.OidcEnabled,
"oidc_client_id": config.OidcClientId,
"oidc_well_known": config.OidcWellKnown,
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
"oidc_token_endpoint": config.OidcTokenEndpoint,
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
},
})
return

View File

@@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
case relaymode.Proxy:
err = controller.RelayProxyHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
}
@@ -85,12 +87,15 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName)
// BUG: bizErr is in race condition
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
// BUG: bizErr is in race condition
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,

3
go.mod
View File

@@ -55,6 +55,7 @@ require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-co-op/gocron v1.37.0 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -87,6 +88,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
@@ -96,6 +98,7 @@ require (
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect

22
go.sum
View File

@@ -31,6 +31,7 @@ github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
@@ -67,6 +68,8 @@ github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0Nglqm
github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -116,6 +119,7 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
@@ -156,7 +160,12 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@@ -177,6 +186,7 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
@@ -184,7 +194,12 @@ github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYde
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
@@ -198,6 +213,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
@@ -217,6 +233,8 @@ go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGX
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -266,6 +284,7 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo=
google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
@@ -296,7 +315,10 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -73,6 +73,8 @@ func main() {
go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(config.SyncFrequency)
}
go model.ScheduleCheckAndDowngrade()
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
@@ -112,4 +114,5 @@ func main() {
if err != nil {
logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}

View File

@@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) {
return
}
}
// set channel id for proxy relay
if channelId := c.Param("channelid"); channelId != "" {
c.Set(ctxkey.SpecificChannelId, channelId)
}
c.Next()
}
}

View File

@@ -3,6 +3,7 @@ package model
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
@@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
ifnull := "ifnull"
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
if username != "" {
tx = tx.Where("username = ?", username)
}

View File

@@ -1,9 +1,11 @@
package model
import (
"github.com/go-co-op/gocron"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"log"
"strconv"
"strings"
"time"
@@ -28,6 +30,7 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -78,6 +81,21 @@ func InitOptionMap() {
loadOptionsFromDatabase()
}
func ScheduleCheckAndDowngrade() {
s := gocron.NewScheduler(time.UTC)
// 设置每天0点执行
_, err := s.Every(1).Day().At("00:00").Do(checkAndDowngradeUsers)
if err != nil {
log.Fatalf("创建调度任务失败: %v", err)
}
// 开始调度
s.StartBlocking()
log.Printf("开始调度")
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
@@ -130,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
@@ -176,6 +196,18 @@ func updateOptionMap(key string, value string) (err error) {
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer":
config.Footer = value
case "SystemName":

View File

@@ -254,14 +254,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -10,7 +10,9 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"log"
"strings"
"time"
)
const (
@@ -39,6 +41,7 @@ type User struct {
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -47,6 +50,7 @@ type User struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
ExpirationDate int64 `json:"expiration_date" gorm:"bigint;default:0;column:expiration_date"` // Expiration date of the user's subscription or account.
}
func GetMaxUserId() int {
@@ -210,6 +214,25 @@ func (user *User) ValidateAndFill() (err error) {
if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
// 校验用户是不是非default,如果是非default,判断到期时间如果过期了降级为default
if !(user.ExpirationDate > 0 && user.Username == "root") {
// 将时间戳转换为 time.Time 类型
expirationTime := time.Unix(user.ExpirationDate, 0)
// 获取当前时间
currentTime := time.Now()
// 比较当前时间和到期时间
if expirationTime.Before(currentTime) {
// 降级为 default
user.Group = "default"
err := DB.Model(user).Updates(user).Error
if err != nil {
fmt.Printf("用户: %s, 降级为 default 时发生错误: %v\n", user.Username, err)
return err
}
fmt.Printf("用户: %s, 特权组过期降为 default\n", user.Username)
}
}
return nil
}
@@ -245,6 +268,14 @@ func (user *User) FillUserByLarkId() error {
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -277,6 +308,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -435,3 +470,48 @@ func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username
}
func checkAndDowngradeUsers() {
// 获取昨天的时间戳
yesterdayTimestamp := time.Now().AddDate(0, 0, -1).Unix()
// 获取需要降级的用户ID列表
var userList []int
query := DB.Model(&User{}).
Where("`Group` != ?", "default").
Where("`username` != ?", "root").
Where("`expiration_date` > 0").
Where("`expiration_date` <= ?", yesterdayTimestamp).
Select("id").
Find(&userList)
// 处理查询错误
if query.Error != nil {
log.Printf("查询用户列表失败: %v", query.Error)
return
}
// 如果没有用户需要降级,直接返回
if len(userList) == 0 {
return
}
// 批量降级用户
updateQuery := DB.Model(&User{}).Where("id IN ?", userList).Update("Group", "default")
// 处理更新错误
if updateQuery.Error != nil {
log.Printf("批量更新用户分组失败: %v", updateQuery.Error)
return
}
// 删除已过期用户的Redis缓存
if common.RedisEnabled {
for _, userId := range userList {
err := common.RedisSet(fmt.Sprintf("user_group:%d", userId), "default", time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
log.Printf("更新用户: %d, 权益缓存失败, Error: %v", userId, err)
}
}
}
}

View File

@@ -1,10 +1,11 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
case "insufficient_quota", "authentication_error", "permission_error", "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
lowerMessage := strings.ToLower(err.Message)
if strings.Contains(lowerMessage, "your access was terminated") ||
strings.Contains(lowerMessage, "violation of our policies") ||
strings.Contains(lowerMessage, "your credit balance is too low") ||
strings.Contains(lowerMessage, "organization has been disabled") ||
strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "已欠费") {
return true
}
return false

View File

@@ -15,6 +15,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -58,6 +59,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &deepl.Adaptor{}
case apitype.VertexAI:
return &vertexai.Adaptor{}
case apitype.Proxy:
return &proxy.Adaptor{}
}
return nil
}

View File

@@ -3,6 +3,7 @@ package ali
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
@@ -59,7 +60,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
@@ -102,8 +103,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
StatusCode: resp.StatusCode,
}, nil
}
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,6 +1,7 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",

View File

@@ -7,8 +7,12 @@ import (
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
switch meta.Mode {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
}
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
}

View File

@@ -3,6 +3,5 @@ package gemini
// https://ai.google.dev/models/gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
"gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa",
}

View File

@@ -4,9 +4,14 @@ package groq
var ModelList = []string{
"gemma-7b-it",
"llama2-7b-2048",
"llama2-70b-4096",
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"gemma2-9b-it",
"llama-3.1-405b-reasoning",
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"whisper-large-v3",
}

View File

@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL)
}
return fullRequestURL, nil
}

View File

@@ -31,6 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
},
Stream: request.Stream,
}
@@ -118,8 +120,10 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
common.SetEventStreamHeaders(c)
for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}")
data = data + "}"
data := scanner.Text()
if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}"
}
var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse)
@@ -157,8 +161,15 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
Model: request.Model,
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
}
@@ -201,15 +212,17 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, 1),
Model: "text-embedding-v1",
Model: response.Model,
Usage: model.Usage{TotalTokens: 0},
}
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: 0,
Embedding: response.Embedding,
})
for i, embedding := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: i,
Embedding: embedding,
})
}
return &openAIEmbeddingResponse
}

View File

@@ -7,6 +7,8 @@ type Options struct {
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type Message struct {
@@ -37,11 +39,15 @@ type ChatResponse struct {
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Input []string `json:"input"`
// Truncate bool `json:"truncate,omitempty"`
Options *Options `json:"options,omitempty"`
// KeepAlive string `json:"keep_alive,omitempty"`
}
type EmbeddingResponse struct {
Error string `json:"error,omitempty"`
Embedding []float64 `json:"embedding,omitempty"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings"`
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
"github.com/songquanpeng/one-api/relay/channeltype"
)
@@ -30,6 +31,7 @@ var CompatibleChannels = []int{
channeltype.DeepSeek,
channeltype.TogetherAI,
channeltype.Novita,
channeltype.SiliconFlow,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -60,6 +62,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "doubao", doubao.ModelList
case channeltype.Novita:
return "novita", novita.ModelList
case channeltype.SiliconFlow:
return "siliconflow", siliconflow.ModelList
default:
return "openai", ModelList
}

View File

@@ -8,6 +8,9 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@@ -110,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string)
}
imageTokens, err := countImageTokens(url, detail)
imageTokens, err := countImageTokens(url, detail, model)
if err != nil {
logger.SysError("error counting image tokens: " + err.Error())
} else {
@@ -134,11 +134,15 @@ const (
lowDetailCost = 85
highDetailCostPerTile = 170
additionalCost = 85
// gpt-4o-mini cost higher than other model
gpt4oMiniLowDetailCost = 2833
gpt4oMiniHighDetailCost = 5667
gpt4oMiniAdditionalCost = 2833
)
// https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string) (_ int, err error) {
func countImageTokens(url string, detail string, model string) (_ int, err error) {
var fetchSize = true
var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
@@ -172,6 +176,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
}
switch detail {
case "low":
if strings.HasPrefix(model, "gpt-4o-mini") {
return gpt4oMiniLowDetailCost, nil
}
return lowDetailCost, nil
case "high":
if fetchSize {
@@ -191,6 +198,9 @@ func countImageTokens(url string, detail string) (_ int, err error) {
height = int(float64(height) * ratio)
}
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
if strings.HasPrefix(model, "gpt-4o-mini") {
return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
}
result := numSquares*highDetailCostPerTile + additionalCost
return result, nil
default:

View File

@@ -0,0 +1,89 @@
package proxy
import (
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
const channelName = "proxy"
type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("notimplement")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
for k, v := range resp.Header {
for _, vv := range v {
c.Writer.Header().Set(k, vv)
}
}
c.Writer.WriteHeader(resp.StatusCode)
if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
return nil, &relaymodel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Message: gerr.Error(),
},
}
}
return nil, nil
}
func (a *Adaptor) GetModelList() (models []string) {
return nil
}
func (a *Adaptor) GetChannelName() string {
return channelName
}
// GetRequestURL remove static prefix, and return the real request url to the upstream service
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
for k, v := range c.Request.Header {
req.Header.Set(k, v[0])
}
// remove unnecessary headers
req.Header.Del("Host")
req.Header.Del("Content-Length")
req.Header.Del("Accept-Encoding")
req.Header.Del("Connection")
// set authorization header
req.Header.Set("Authorization", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return nil, errors.Errorf("not implement")
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -0,0 +1,36 @@
package siliconflow
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"THUDM/glm-4-9b-chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}

View File

@@ -1,7 +1,13 @@
package stepfun
var ModelList = []string{
"step-1-8k",
"step-1-32k",
"step-1-128k",
"step-1-256k",
"step-1-flash",
"step-2-16k",
"step-1v-8k",
"step-1v-32k",
"step-1-200k",
"step-1x-medium",
}

View File

@@ -18,6 +18,7 @@ const (
Cloudflare
DeepL
VertexAI
Proxy
Dummy // this one is only for count, do not add any channel after this
)

View File

@@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{
"720x1280": 1,
"1280x720": 1,
},
"step-1x-medium": {
"256x256": 1,
"512x512": 1,
"768x768": 1,
"1024x1024": 1,
"1280x800": 1,
"800x1280": 1,
},
}
var ImageGenerationAmounts = map[string][2]int{
@@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{
"ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1},
"step-1x-medium": {1, 1},
}
var ImagePromptLengthLimitations = map[string]int{
@@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{
"ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
"step-1x-medium": 4000,
}
var ImageOriginModelName = map[string]string{

View File

@@ -28,15 +28,19 @@ var ModelRatio = map[string]float64{
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
@@ -96,12 +100,11 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro": 1,
"gemini-1.5-flash": 1,
"gemini-1.5-pro": 1,
"aqa": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
@@ -156,20 +159,29 @@ var ModelRatio = map[string]float64{
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
"llama3-70b-8192": 0.59 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"llama3-8b-8192": 0.05 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
"llama2-70b-4096": 0.64 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"gemma-7b-it": 0.07 / 1000000 * USD,
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
"llama3-8b-8192": 0.05 / 1000000 * USD,
"llama3-70b-8192": 0.59 / 1000000 * USD,
"gemma2-9b-it": 0.20 / 1000000 * USD,
"llama-3.1-405b-reasoning": 0.89 / 1000000 * USD,
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD,
"llama-3.1-8b-instant": 0.05 / 1000000 * USD,
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
// stepfun todo
"step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.024 * RMB,
"step-1-200k": 0.15 * RMB,
// https://platform.stepfun.com/docs/pricing/details
"step-1-8k": 0.005 / 1000 * RMB,
"step-1-32k": 0.015 / 1000 * RMB,
"step-1-128k": 0.040 / 1000 * RMB,
"step-1-256k": 0.095 / 1000 * RMB,
"step-1-flash": 0.001 / 1000 * RMB,
"step-2-16k": 0.038 / 1000 * RMB,
"step-1v-8k": 0.005 / 1000 * RMB,
"step-1v-32k": 0.015 / 1000 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
@@ -195,8 +207,10 @@ var CompletionRatio = map[string]float64{
"llama3-70b-8192(33)": 0.0035 / 0.00265,
}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
var (
DefaultModelRatio map[string]float64
DefaultCompletionRatio map[string]float64
)
func init() {
DefaultModelRatio = make(map[string]float64)
@@ -308,6 +322,9 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") ||
strings.HasPrefix(name, "gpt-4o") ||
strings.HasSuffix(name, "preview") {
@@ -315,6 +332,9 @@ func GetCompletionRatio(name string, channelType int) float64 {
}
return 2
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-3") {
return 5
}

View File

@@ -44,5 +44,7 @@ const (
Doubao
Novita
VertextAI
Proxy
SiliconFlow
Dummy
)

View File

@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
apiType = apitype.DeepL
case VertextAI:
apiType = apitype.VertexAI
case Proxy:
apiType = apitype.Proxy
}
return apiType

View File

@@ -44,6 +44,8 @@ var ChannelBaseURLs = []string{
"https://ark.cn-beijing.volces.com", // 40
"https://api.novita.ai/v3/openai", // 41
"", // 42
"", // 43
"https://api.siliconflow.cn", // 44
}
func init() {

41
relay/controller/proxy.go Normal file
View File

@@ -0,0 +1,41 @@
// Package controller is a package for handling the relay controller
package controller
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
// RelayProxyHelper is a helper function to proxy the request to the upstream service
func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(meta)
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
@@ -31,9 +32,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
@@ -55,30 +55,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta)
// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
// do request
@@ -103,3 +82,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

View File

@@ -18,11 +18,12 @@ type Meta struct {
UserId int
Group string
ModelMapping map[string]string
BaseURL string
APIKey string
APIType int
Config model.ChannelConfig
IsStream bool
// BaseURL is the proxy url set in the channel config
BaseURL string
APIKey string
APIType int
Config model.ChannelConfig
IsStream bool
// OriginModelName is the model name from the raw user request
OriginModelName string
// ActualModelName is the model name after mapping

View File

@@ -1,7 +1,15 @@
package model
type ResponseFormat struct {
Type string `json:"type,omitempty"`
Type string `json:"type,omitempty"`
JsonSchema *JSONSchema `json:"json_schema,omitempty"`
}
type JSONSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema map[string]interface{} `json:"schema,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
type GeneralOpenAIRequest struct {
@@ -29,6 +37,7 @@ type GeneralOpenAIRequest struct {
Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -11,4 +11,6 @@ const (
AudioSpeech
AudioTranscription
AudioTranslation
// Proxy is a special relay mode for proxying requests to custom upstream
Proxy
)

View File

@@ -24,6 +24,8 @@ func GetByPath(path string) int {
relayMode = AudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = AudioTranslation
} else if strings.HasPrefix(path, "/v1/oneapi/proxy") {
relayMode = Proxy
}
return relayMode
}

View File

@@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)

View File

@@ -19,6 +19,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.Any("/oneapi/proxy/:channelid/*target", controller.Relay)
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)

View File

@@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken';
const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' }
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' }
];
function renderTimestamp(timestamp) {
@@ -60,7 +62,12 @@ const TokensTable = () => {
onOpenLink('next-mj');
}
},
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }
{ node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' },
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
];
const columns = [
@@ -177,6 +184,11 @@ const TokensTable = () => {
node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => {
onOpenLink('opencat', record.key);
}
},
{
node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => {
onOpenLink('lobechat');
}
}
]
}
@@ -382,6 +394,9 @@ const TokensTable = () => {
case 'next-mj':
url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
if (!chatLink) {
showError('管理员未设置聊天链接');

View File

@@ -1,10 +1,13 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 33, text: 'AWS', value: 33, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
{ key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
{ key: 41, text: 'Novita', value: 41, color: 'purple' },
{ key: 40, text: '字节跳动豆包', value: 40, color: 'blue' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
@@ -17,6 +20,16 @@ export const CHANNEL_OPTIONS = [
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 32, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 34, text: 'Coze', value: 34, color: 'blue' },
{ key: 35, text: 'Cohere', value: 35, color: 'blue' },
{ key: 36, text: 'DeepSeek', value: 36, color: 'black' },
{ key: 37, text: 'Cloudflare', value: 37, color: 'orange' },
{ key: 38, text: 'DeepL', value: 38, color: 'black' },
{ key: 39, text: 'together.ai', value: 39, color: 'blue' },
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },
@@ -34,4 +47,4 @@ export const CHANNEL_OPTIONS = [
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
CHANNEL_OPTIONS[i].label = CHANNEL_OPTIONS[i].text;
}
}

View File

@@ -17,6 +17,7 @@
"@tabler/icons-react": "^2.44.0",
"apexcharts": "3.35.3",
"axios": "^0.27.2",
"date-fns": "^3.6.0",
"dayjs": "^1.11.10",
"formik": "^2.2.9",
"framer-motion": "^6.3.16",
@@ -27,6 +28,7 @@
"prop-types": "^15.8.1",
"react": "^18.2.0",
"react-apexcharts": "1.4.0",
"react-datepicker": "^7.3.0",
"react-device-detect": "^2.2.2",
"react-dom": "^18.2.0",
"react-perfect-scrollbar": "^1.5.8",

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 5.4 KiB

After

Width:  |  Height:  |  Size: 4.3 KiB

View File

@@ -0,0 +1,7 @@
<svg t="1723135116886" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg"
p-id="10969" width="200" height="200">
<path d="M512 960C265 960 64 759 64 512S265 64 512 64s448 201 448 448-201 448-448 448z m0-882.6c-239.7 0-434.6 195-434.6 434.6s195 434.6 434.6 434.6 434.6-195 434.6-434.6S751.7 77.4 512 77.4z"
p-id="10970" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="60"></path>
<path d="M197.7 512c0-78.3 31.6-98.8 87.2-98.8 56.2 0 87.2 20.5 87.2 98.8s-31 98.8-87.2 98.8c-55.7 0-87.2-20.5-87.2-98.8z m130.4 0c0-46.8-7.8-64.5-43.2-64.5-35.2 0-42.9 17.7-42.9 64.5 0 47.1 7.8 63.7 42.9 63.7 35.4 0 43.2-16.6 43.2-63.7zM409.7 415.9h42.1V608h-42.1V415.9zM653.9 512c0 74.2-37.1 96.1-93.6 96.1h-65.9V415.9h65.9c56.5 0 93.6 16.1 93.6 96.1z m-43.5 0c0-49.3-17.7-60.6-52.3-60.6h-21.6v120.7h21.6c35.4 0 52.3-13.3 52.3-60.1zM686.5 512c0-74.2 36.3-98.8 92.7-98.8 18.3 0 33.2 2.2 44.8 6.4v36.3c-11.9-4.2-26-6.6-42.1-6.6-34.6 0-49.8 15.5-49.8 62.6 0 50.1 15.2 62.6 49.3 62.6 15.8 0 30.2-2.2 44.8-7.5v36c-11.3 4.7-28.5 8-46.8 8-56.1-0.2-92.9-18.7-92.9-99z"
p-id="10971" fill="#2c2c2c" stroke="#2c2c2c" stroke-width="20"></path>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -22,7 +22,12 @@ const config = {
turnstile_site_key: '',
version: '',
wechat_login: false,
wechat_qrcode: ''
wechat_qrcode: '',
oidc: false,
oidc_client_id: '',
oidc_authorization_endpoint: '',
oidc_token_endpoint: '',
oidc_userinfo_endpoint: '',
}
};

View File

@@ -167,6 +167,18 @@ export const CHANNEL_OPTIONS = {
value: 42,
color: 'primary'
},
43: {
key: 43,
text: 'Proxy',
value: 43,
color: 'primary'
},
44: {
key: 44,
text: 'SiliconFlow',
value: 44,
color: 'primary'
},
41: {
key: 41,
text: 'Novita',

View File

@@ -70,6 +70,28 @@ const useLogin = () => {
}
};
const oidcLogin = async (code, state) => {
try {
const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/panel');
} else {
dispatch({ type: LOGIN, payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/panel');
}
}
return { success, message };
} catch (err) {
// 请求失败,设置错误信息
return { success: false, message: '' };
}
}
const wechatLogin = async (code) => {
try {
const res = await API.get(`/api/oauth/wechat?code=${code}`);
@@ -94,7 +116,7 @@ const useLogin = () => {
navigate('/');
};
return { login, logout, githubLogin, wechatLogin, larkLogin };
return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin };
};
export default useLogin;

View File

@@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login'))
const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register')));
const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth')));
const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth')));
const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth')));
const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword')));
const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword')));
const Home = Loadable(lazy(() => import('views/Home')));
@@ -53,6 +54,10 @@ const OtherRoutes = {
path: '/oauth/lark',
element: <LarkOAuth />
},
{
path: 'oauth/oidc',
element: <OidcOAuth />
},
{
path: '/404',
element: <NotFoundView />

View File

@@ -98,6 +98,21 @@ export async function onLarkOAuthClicked(lark_client_id) {
window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`);
}
export async function onOidcClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState();
if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/oidc`;
const response_type = "code";
const scope = "openid profile email";
const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`;
if (openInNewTab) {
window.open(url);
} else
{
window.location.href = url;
}
}
export function isAdmin() {
let user = localStorage.getItem('user');
if (!user) return false;

View File

@@ -0,0 +1,94 @@
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import React, { useEffect, useState } from 'react';
import { showError } from 'utils/common';
import useLogin from 'hooks/useLogin';
// material-ui
import { useTheme } from '@mui/material/styles';
import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material';
// project imports
import AuthWrapper from '../AuthWrapper';
import AuthCardWrapper from '../AuthCardWrapper';
import Logo from 'ui-component/Logo';
// assets
// ================================|| AUTH3 - LOGIN ||================================ //
const OidcOAuth = () => {
const theme = useTheme();
const matchDownSM = useMediaQuery(theme.breakpoints.down('md'));
const [searchParams] = useSearchParams();
const [prompt, setPrompt] = useState('处理中...');
const { oidcLogin } = useLogin();
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const { success, message } = await oidcLogin(code, state);
if (!success) {
if (message) {
showError(message);
}
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
await new Promise((resolve) => setTimeout(resolve, 2000));
navigate('/login');
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<AuthWrapper>
<Grid container direction="column" justifyContent="flex-end">
<Grid item xs={12}>
<Grid container justifyContent="center" alignItems="center" sx={{ minHeight: 'calc(100vh - 136px)' }}>
<Grid item sx={{ m: { xs: 1, sm: 3 }, mb: 0 }}>
<AuthCardWrapper>
<Grid container spacing={2} alignItems="center" justifyContent="center">
<Grid item sx={{ mb: 3 }}>
<Link to="#">
<Logo />
</Link>
</Grid>
<Grid item xs={12}>
<Grid container direction={matchDownSM ? 'column-reverse' : 'row'} alignItems="center" justifyContent="center">
<Grid item>
<Stack alignItems="center" justifyContent="center" spacing={1}>
<Typography color={theme.palette.primary.main} gutterBottom variant={matchDownSM ? 'h3' : 'h2'}>
OIDC 登录
</Typography>
</Stack>
</Grid>
</Grid>
</Grid>
<Grid item xs={12} container direction="column" justifyContent="center" alignItems="center" style={{ height: '200px' }}>
<CircularProgress />
<Typography variant="h3" paddingTop={'20px'}>
{prompt}
</Typography>
</Grid>
</Grid>
</AuthCardWrapper>
</Grid>
</Grid>
</Grid>
</Grid>
</AuthWrapper>
);
};
export default OidcOAuth;

View File

@@ -36,7 +36,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff';
import Github from 'assets/images/icons/github.svg';
import Wechat from 'assets/images/icons/wechat.svg';
import Lark from 'assets/images/icons/lark.svg';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common';
import OIDC from 'assets/images/icons/oidc.svg';
import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common';
// ============================|| FIREBASE - LOGIN ||============================ //
@@ -50,7 +51,7 @@ const LoginForm = ({ ...others }) => {
// const [checked, setChecked] = useState(true);
let tripartiteLogin = false;
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) {
if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) {
tripartiteLogin = true;
}
@@ -145,6 +146,29 @@ const LoginForm = ({ ...others }) => {
</AnimateButton>
</Grid>
)}
{siteInfo.oidc && (
<Grid item xs={12}>
<AnimateButton>
<Button
disableElevation
fullWidth
onClick={() => onOidcClicked(siteInfo.oidc_authorization_endpoint,siteInfo.oidc_client_id)}
size="large"
variant="outlined"
sx={{
color: 'grey.700',
backgroundColor: theme.palette.grey[50],
borderColor: theme.palette.grey[100]
}}
>
<Box sx={{ mr: { xs: 1, sm: 2, width: 20 }, display: 'flex', alignItems: 'center' }}>
<img src={OIDC} alt="Lark" width={25} height={25} style={{ marginRight: matchDownSM ? 8 : 16 }} />
</Box>
使用 OIDC 登录
</Button>
</AnimateButton>
</Grid>
)}
<Grid item xs={12}>
<Box
sx={{

View File

@@ -20,7 +20,7 @@ import SubCard from 'ui-component/cards/SubCard';
import { IconBrandWechat, IconBrandGithub, IconMail } from '@tabler/icons-react';
import Label from 'ui-component/Label';
import { API } from 'utils/api';
import { showError, showSuccess } from 'utils/common';
import { onOidcClicked, showError, showSuccess } from 'utils/common';
import { onGitHubOAuthClicked, onLarkOAuthClicked, copy } from 'utils/common';
import * as Yup from 'yup';
import WechatModal from 'views/Authentication/AuthForms/WechatModal';
@@ -28,6 +28,7 @@ import { useSelector } from 'react-redux';
import EmailModal from './component/EmailModal';
import Turnstile from 'react-turnstile';
import { ReactComponent as Lark } from 'assets/images/icons/lark.svg';
import { ReactComponent as OIDC } from 'assets/images/icons/oidc.svg';
const validationSchema = Yup.object().shape({
username: Yup.string().required('用户名 不能为空').min(3, '用户名 不能小于 3 个字符'),
@@ -123,6 +124,15 @@ export default function Profile() {
loadUser().then();
}, [status]);
function getOidcId(){
if (!inputs.oidc_id) return '';
let oidc_id = inputs.oidc_id;
if (inputs.oidc_id.length > 8) {
oidc_id = inputs.oidc_id.slice(0, 6) + '...' + inputs.oidc_id.slice(-6);
}
return oidc_id;
}
return (
<>
<UserCard>
@@ -141,6 +151,9 @@ export default function Profile() {
<Label variant="ghost" color={inputs.lark_id ? 'primary' : 'default'}>
<SvgIcon component={Lark} inheritViewBox="0 0 24 24" /> {inputs.lark_id || '未绑定'}
</Label>
<Label variant="ghost" color={inputs.oidc_id ? 'primary' : 'default'}>
<SvgIcon component={OIDC} inheritViewBox="0 0 24 24" /> {getOidcId() || '未绑定'}
</Label>
</Stack>
<SubCard title="个人信息">
<Grid container spacing={2}>
@@ -216,6 +229,13 @@ export default function Profile() {
</Button>
</Grid>
)}
{status.oidc && !inputs.oidc_id && (
<Grid xs={12} md={4}>
<Button variant="contained" onClick={() => onOidcClicked(status.oidc_authorization_endpoint,status.oidc_client_id,true)}>
绑定 OIDC 账号
</Button>
</Grid>
)}
<Grid xs={12} md={4}>
<Button
variant="contained"

View File

@@ -33,6 +33,13 @@ const SystemSetting = () => {
GitHubClientSecret: '',
LarkClientId: '',
LarkClientSecret: '',
OidcEnabled: '',
OidcWellKnown: '',
OidcClientId: '',
OidcClientSecret: '',
OidcAuthorizationEndpoint: '',
OidcTokenEndpoint: '',
OidcUserinfoEndpoint: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -94,6 +101,7 @@ const SystemSetting = () => {
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled':
case 'OidcEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
@@ -142,8 +150,15 @@ const SystemSetting = () => {
name === 'MessagePusherAddress' ||
name === 'MessagePusherToken' ||
name === 'LarkClientId' ||
name === 'LarkClientSecret'
) {
name === 'LarkClientSecret' ||
name === 'OidcClientId' ||
name === 'OidcClientSecret' ||
name === 'OidcWellKnown' ||
name === 'OidcAuthorizationEndpoint' ||
name === 'OidcTokenEndpoint' ||
name === 'OidcUserinfoEndpoint'
)
{
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
await updateOption(name, value);
@@ -225,6 +240,43 @@ const SystemSetting = () => {
}
};
const submitOidc = async () => {
if (inputs.OidcWellKnown !== '') {
if (!inputs.OidcWellKnown.startsWith('http://') && !inputs.OidcWellKnown.startsWith('https://')) {
showError('Well-Known URL 必须以 http:// 或 https:// 开头');
return;
}
try {
const res = await API.get(inputs.OidcWellKnown);
inputs.OidcAuthorizationEndpoint = res.data['authorization_endpoint'];
inputs.OidcTokenEndpoint = res.data['token_endpoint'];
inputs.OidcUserinfoEndpoint = res.data['userinfo_endpoint'];
showSuccess('获取 OIDC 配置成功!');
} catch (err) {
showError("获取 OIDC 配置失败,请检查网络状况和 Well-Known URL 是否正确");
}
}
if (originInputs['OidcWellKnown'] !== inputs.OidcWellKnown) {
await updateOption('OidcWellKnown', inputs.OidcWellKnown);
}
if (originInputs['OidcClientId'] !== inputs.OidcClientId) {
await updateOption('OidcClientId', inputs.OidcClientId);
}
if (originInputs['OidcClientSecret'] !== inputs.OidcClientSecret && inputs.OidcClientSecret !== '') {
await updateOption('OidcClientSecret', inputs.OidcClientSecret);
}
if (originInputs['OidcAuthorizationEndpoint'] !== inputs.OidcAuthorizationEndpoint) {
await updateOption('OidcAuthorizationEndpoint', inputs.OidcAuthorizationEndpoint);
}
if (originInputs['OidcTokenEndpoint'] !== inputs.OidcTokenEndpoint) {
await updateOption('OidcTokenEndpoint', inputs.OidcTokenEndpoint);
}
if (originInputs['OidcUserinfoEndpoint'] !== inputs.OidcUserinfoEndpoint) {
await updateOption('OidcUserinfoEndpoint', inputs.OidcUserinfoEndpoint);
}
};
return (
<>
<Stack spacing={2}>
@@ -291,6 +343,12 @@ const SystemSetting = () => {
control={<Checkbox checked={inputs.GitHubOAuthEnabled === 'true'} onChange={handleInputChange} name="GitHubOAuthEnabled" />}
/>
</Grid>
<Grid xs={12} md={3}>
<FormControlLabel
label="允许通过 OIDC 登录 & 注册"
control={<Checkbox checked={inputs.OidcEnabled === 'true'} onChange={handleInputChange} name="OidcEnabled" />}
/>
</Grid>
<Grid xs={12} md={3}>
<FormControlLabel
label="允许通过微信登录 & 注册"
@@ -616,6 +674,117 @@ const SystemSetting = () => {
</Grid>
</Grid>
</SubCard>
<SubCard
title="配置 OIDC"
subTitle={
<span>
用以支持通过 OIDC 登录例如 OktaAuth0 等兼容 OIDC 协议的 IdP
</span>
}
>
<Grid container spacing={ { xs: 3, sm: 2, md: 4 } }>
<Grid xs={ 12 } md={ 12 }>
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
主页链接填 <code>{ inputs.ServerAddress }</code>
重定向 URL <code>{ `${ inputs.ServerAddress }/oauth/oidc` }</code>
</Alert> <br />
<Alert severity="info" sx={ { wordWrap: 'break-word' } }>
若你的 OIDC Provider 支持 Discovery Endpoint你可以仅填写 OIDC Well-Known URL系统会自动获取 OIDC 配置
</Alert>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcClientId">Client ID</InputLabel>
<OutlinedInput
id="OidcClientId"
name="OidcClientId"
value={ inputs.OidcClientId || '' }
onChange={ handleInputChange }
label="Client ID"
placeholder="输入 OIDC 的 Client ID"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcClientSecret">Client Secret</InputLabel>
<OutlinedInput
id="OidcClientSecret"
name="OidcClientSecret"
value={ inputs.OidcClientSecret || '' }
onChange={ handleInputChange }
label="Client Secret"
placeholder="敏感信息不会发送到前端显示"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcWellKnown">Well-Known URL</InputLabel>
<OutlinedInput
id="OidcWellKnown"
name="OidcWellKnown"
value={ inputs.OidcWellKnown || '' }
onChange={ handleInputChange }
label="Well-Known URL"
placeholder="请输入 OIDC 的 Well-Known URL"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcAuthorizationEndpoint">Authorization Endpoint</InputLabel>
<OutlinedInput
id="OidcAuthorizationEndpoint"
name="OidcAuthorizationEndpoint"
value={ inputs.OidcAuthorizationEndpoint || '' }
onChange={ handleInputChange }
label="Authorization Endpoint"
placeholder="输入 OIDC 的 Authorization Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcTokenEndpoint">Token Endpoint</InputLabel>
<OutlinedInput
id="OidcTokenEndpoint"
name="OidcTokenEndpoint"
value={ inputs.OidcTokenEndpoint || '' }
onChange={ handleInputChange }
label="Token Endpoint"
placeholder="输入 OIDC 的 Token Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 } md={ 6 }>
<FormControl fullWidth>
<InputLabel htmlFor="OidcUserinfoEndpoint">Userinfo Endpoint</InputLabel>
<OutlinedInput
id="OidcUserinfoEndpoint"
name="OidcUserinfoEndpoint"
value={ inputs.OidcUserinfoEndpoint || '' }
onChange={ handleInputChange }
label="Userinfo Endpoint"
placeholder="输入 OIDC 的 Userinfo Endpoint"
disabled={ loading }
/>
</FormControl>
</Grid>
<Grid xs={ 12 }>
<Button variant="contained" onClick={ submitOidc }>
保存 OIDC 设置
</Button>
</Grid>
</Grid>
</SubCard>
<SubCard
title="配置 Message Pusher"
subTitle={

View File

@@ -32,7 +32,8 @@ const COPY_OPTIONS = [
encode: false
},
{ key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true },
{ key: 'lobechat', text: 'LobeChat', url: 'https://lobehub.com/?settings={"keyVaults":{"openai":{"apiKey":"user-key","baseURL":"https://your-proxy.com/v1"}}}', encode: true }
];
function replacePlaceholders(text, key, serverAddress) {

View File

@@ -2,7 +2,9 @@ import PropTypes from 'prop-types';
import * as Yup from 'yup';
import { Formik } from 'formik';
import { useTheme } from '@mui/material/styles';
import { useState, useEffect } from 'react';
import { useState, useEffect, useCallback } from 'react';
import { format } from 'date-fns';
import {
Dialog,
DialogTitle,
@@ -17,7 +19,11 @@ import {
Select,
MenuItem,
IconButton,
FormHelperText
FormHelperText,
TextField,
Typography,
Switch,
FormControlLabel
} from '@mui/material';
import Visibility from '@mui/icons-material/Visibility';
@@ -44,6 +50,17 @@ const validationSchema = Yup.object().shape({
is: false,
then: Yup.number().min(0, '额度 不能小于 0'),
otherwise: Yup.number()
}),
expiration_date: Yup.mixed().when('group', {
is: (group) => group !== 'default',
then: Yup.mixed().test(
'expiration_date-required',
'到期时间 不能为空',
function (value) {
const { expiration_date } = this.parent;
return expiration_date === -1 || !!expiration_date;
}
),
})
});
@@ -53,7 +70,8 @@ const originInputs = {
display_name: '',
password: '',
group: 'default',
quota: 0
quota: 0,
expiration_date: null
};
const EditModal = ({ open, userId, onCancel, onOk }) => {
@@ -65,6 +83,12 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
// 将到期时间转换为 Unix 时间戳
if (values.expiration_date && values.expiration_date !== -1) {
const date = new Date(values.expiration_date);
values.expiration_date = Math.floor(date.getTime() / 1000); // 转换为秒级的 Unix 时间戳
}
let res;
if (values.is_edit) {
res = await API.put(`/api/user/`, { ...values, id: parseInt(userId) });
@@ -95,16 +119,23 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
event.preventDefault();
};
const loadUser = async () => {
const loadUser = useCallback(async () => {
let res = await API.get(`/api/user/${userId}`);
const { success, message, data } = res.data;
if (success) {
data.is_edit = true;
// 将 Unix 时间戳转换为日期字符串
if (data.expiration_date && data.expiration_date !== -1) {
const date = new Date(data.expiration_date * 1000); // 转换为毫秒级的时间戳
data.expiration_date = format(date, 'yyyy-MM-dd'); // 格式化为 date 格式
}
setInputs(data);
} else {
showError(message);
}
};
}, [userId]);
const fetchGroups = async () => {
try {
@@ -122,159 +153,203 @@ const EditModal = ({ open, userId, onCancel, onOk }) => {
} else {
setInputs(originInputs);
}
}, [userId]);
}, [userId, loadUser]);
return (
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
{userId ? '编辑用户' : '新建用户'}
</DialogTitle>
<Divider />
<DialogContent>
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
{({ errors, handleBlur, handleChange, handleSubmit, touched, values, isSubmitting }) => (
<form noValidate onSubmit={handleSubmit}>
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
<OutlinedInput
id="channel-username-label"
label="用户名"
type="text"
value={values.username}
name="username"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'username' }}
aria-describedby="helper-text-channel-username-label"
/>
{touched.username && errors.username && (
<FormHelperText error id="helper-tex-channel-username-label">
{errors.username}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
<OutlinedInput
id="channel-display_name-label"
label="显示名称"
type="text"
value={values.display_name}
name="display_name"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'display_name' }}
aria-describedby="helper-text-channel-display_name-label"
/>
{touched.display_name && errors.display_name && (
<FormHelperText error id="helper-tex-channel-display_name-label">
{errors.display_name}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
<OutlinedInput
id="channel-password-label"
label="密码"
type={showPassword ? 'text' : 'password'}
value={values.password}
name="password"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'password' }}
endAdornment={
<InputAdornment position="end">
<IconButton
aria-label="toggle password visibility"
onClick={handleClickShowPassword}
onMouseDown={handleMouseDownPassword}
edge="end"
size="large"
>
{showPassword ? <Visibility /> : <VisibilityOff />}
</IconButton>
</InputAdornment>
}
aria-describedby="helper-text-channel-password-label"
/>
{touched.password && errors.password && (
<FormHelperText error id="helper-tex-channel-password-label">
{errors.password}
</FormHelperText>
)}
</FormControl>
{values.is_edit && (
<>
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
{userId ? '编辑用户' : '新建用户'}
</DialogTitle>
<Divider />
<DialogContent>
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
{({ errors, handleBlur, handleChange, handleSubmit, setFieldValue, touched, values, isSubmitting }) => (
<form noValidate onSubmit={handleSubmit}>
<FormControl fullWidth error={Boolean(touched.username && errors.username)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-username-label">用户名</InputLabel>
<OutlinedInput
id="channel-quota-label"
label="额度"
type="number"
value={values.quota}
name="quota"
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-quota-label"
disabled={values.unlimited_quota}
id="channel-username-label"
label="用户名"
type="text"
value={values.username}
name="username"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'username' }}
aria-describedby="helper-text-channel-username-label"
/>
{touched.quota && errors.quota && (
<FormHelperText error id="helper-tex-channel-quota-label">
{errors.quota}
</FormHelperText>
{touched.username && errors.username && (
<FormHelperText error id="helper-tex-channel-username-label">
{errors.username}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
<Select
id="channel-group-label"
label="分组"
value={values.group}
name="group"
onBlur={handleBlur}
onChange={handleChange}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
<FormControl fullWidth error={Boolean(touched.display_name && errors.display_name)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-display_name-label">显示名称</InputLabel>
<OutlinedInput
id="channel-display_name-label"
label="显示名称"
type="text"
value={values.display_name}
name="display_name"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'display_name' }}
aria-describedby="helper-text-channel-display_name-label"
/>
{touched.display_name && errors.display_name && (
<FormHelperText error id="helper-tex-channel-display_name-label">
{errors.display_name}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.password && errors.password)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-password-label">密码</InputLabel>
<OutlinedInput
id="channel-password-label"
label="密码"
type={showPassword ? 'text' : 'password'}
value={values.password}
name="password"
onBlur={handleBlur}
onChange={handleChange}
inputProps={{ autoComplete: 'password' }}
endAdornment={
<InputAdornment position="end">
<IconButton
aria-label="toggle password visibility"
onClick={handleClickShowPassword}
onMouseDown={handleMouseDownPassword}
edge="end"
size="large"
>
{showPassword ? <Visibility /> : <VisibilityOff />}
</IconButton>
</InputAdornment>
}
}}
>
{groupOptions.map((option) => {
return (
<MenuItem key={option} value={option}>
{option}
</MenuItem>
);
})}
</Select>
{touched.group && errors.group && (
<FormHelperText error id="helper-tex-channel-group-label">
{errors.group}
</FormHelperText>
aria-describedby="helper-text-channel-password-label"
/>
{touched.password && errors.password && (
<FormHelperText error id="helper-tex-channel-password-label">
{errors.password}
</FormHelperText>
)}
</FormControl>
</>
)}
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
提交
</Button>
</DialogActions>
</form>
)}
</Formik>
</DialogContent>
</Dialog>
{values.is_edit && (
<>
<FormControl fullWidth error={Boolean(touched.quota && errors.quota)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-quota-label">额度</InputLabel>
<OutlinedInput
id="channel-quota-label"
label="额度"
type="number"
value={values.quota}
name="quota"
endAdornment={<InputAdornment position="end">{renderQuotaWithPrompt(values.quota)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-quota-label"
disabled={values.unlimited_quota}
/>
{touched.quota && errors.quota && (
<FormHelperText error id="helper-tex-channel-quota-label">
{errors.quota}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.group && errors.group)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-group-label">分组</InputLabel>
<Select
id="channel-group-label"
label="分组"
value={values.group}
name="group"
onBlur={handleBlur}
onChange={(e) => {
handleChange(e);
if (e.target.value === 'default') {
setFieldValue('expiration_date', null);
}
}}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
}
}}
>
{groupOptions.map((option) => {
return (
<MenuItem key={option} value={option}>
{option}
</MenuItem>
);
})}
</Select>
{touched.group && errors.group && (
<FormHelperText error id="helper-tex-channel-group-label">
{errors.group}
</FormHelperText>
)}
</FormControl>
</>
)}
{values.group !== 'default' && (
<FormControl fullWidth error={Boolean(touched.expiration_date && errors.expiration_date)} sx={{ ...theme.typography.otherInput }}>
<TextField
id="channel-expiration_date-label"
label="到期时间"
type="date" // 修改为 date
value={values.expiration_date}
name="expiration_date"
onBlur={handleBlur}
onChange={handleChange}
InputLabelProps={{
shrink: true
}}
inputProps={{
max: '9999-12-31' // 设置最大日期
}}
aria-describedby="helper-text-channel-expiration_date-label"
disabled={values.expiration_date === -1}
/>
{touched.expiration_date && errors.expiration_date && (
<FormHelperText error id="helper-tex-channel-expiration_date-label">
{errors.expiration_date}
</FormHelperText>
)}
<FormControlLabel
control={
<Switch
checked={values.expiration_date === -1}
onChange={(e) => setFieldValue('expiration_date', e.target.checked ? -1 : '')}
name="permanent"
color="primary"
/>
}
label="永不过期"
/>
</FormControl>
)}
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
提交
</Button>
</DialogActions>
</form>
)}
</Formik>
</DialogContent>
</Dialog>
);
};
@@ -285,4 +360,4 @@ EditModal.propTypes = {
userId: PropTypes.number,
onCancel: PropTypes.func,
onOk: PropTypes.func
};
};

View File

@@ -10,12 +10,14 @@ const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'BotGem', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'BotGem', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
{ key: 'lobechat', text: 'LobeChat', value: 'lobechat' },
];
function renderTimestamp(timestamp) {
@@ -114,6 +116,9 @@ const TokensTable = () => {
case 'next':
url = nextUrl;
break;
case 'lobechat':
url = nextLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
url = `sk-${key}`;
}
@@ -153,7 +158,11 @@ const TokensTable = () => {
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
case 'lobechat':
url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}"/v1"}}}`;
break;
default:
url = defaultUrl;
}

View File

@@ -1,44 +1,46 @@
export const CHANNEL_OPTIONS = [
{key: 1, text: 'OpenAI', value: 1, color: 'green'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black'},
{key: 33, text: 'AWS', value: 33, color: 'black'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'},
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange'},
{key: 28, text: 'Mistral AI', value: 28, color: 'orange'},
{key: 41, text: 'Novita', value: 41, color: 'purple'},
{key: 40, text: '字节跳动豆包', value: 40, color: 'blue'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange'},
{key: 18, text: '讯飞星火认知', value: 18, color: 'blue'},
{key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet'},
{key: 19, text: '360 智脑', value: 19, color: 'blue'},
{key: 25, text: 'Moonshot AI', value: 25, color: 'black'},
{key: 23, text: '腾讯混元', value: 23, color: 'teal'},
{key: 26, text: '百川大模型', value: 26, color: 'orange'},
{key: 27, text: 'MiniMax', value: 27, color: 'red'},
{key: 29, text: 'Groq', value: 29, color: 'orange'},
{key: 30, text: 'Ollama', value: 30, color: 'black'},
{key: 31, text: '零一万物', value: 31, color: 'green'},
{key: 32, text: '阶跃星辰', value: 32, color: 'blue'},
{key: 34, text: 'Coze', value: 34, color: 'blue'},
{key: 35, text: 'Cohere', value: 35, color: 'blue'},
{key: 36, text: 'DeepSeek', value: 36, color: 'black'},
{key: 37, text: 'Cloudflare', value: 37, color: 'orange'},
{key: 38, text: 'DeepL', value: 38, color: 'black'},
{key: 39, text: 'together.ai', value: 39, color: 'blue'},
{key: 42, text: 'VertexAI', value: 42, color: 'blue'},
{key: 8, text: '自定义渠道', value: 8, color: 'pink'},
{key: 22, text: '知识库FastGPT', value: 22, color: 'blue'},
{key: 21, text: '知识库AI Proxy', value: 21, color: 'purple'},
{key: 20, text: '代理OpenRouter', value: 20, color: 'black'},
{key: 2, text: '代理API2D', value: 2, color: 'blue'},
{key: 5, text: '代理OpenAI-SB', value: 5, color: 'brown'},
{key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple'},
{key: 10, text: '代理:AI Proxy', value: 10, color: 'purple'},
{key: 4, text: '代理:CloseAI', value: 4, color: 'teal'},
{key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet'},
{key: 9, text: '代理:AI.LS', value: 9, color: 'yellow'},
{key: 12, text: '代理:API2GPT', value: 12, color: 'blue'},
{key: 13, text: '代理AIGC2D', value: 13, color: 'purple'}
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
{ key: 33, text: 'AWS', value: 33, color: 'black' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
{ key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
{ key: 41, text: 'Novita', value: 41, color: 'purple' },
{ key: 40, text: '字节跳动豆包', value: 40, color: 'blue' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
{ key: 25, text: 'Moonshot AI', value: 25, color: 'black' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
{ key: 27, text: 'MiniMax', value: 27, color: 'red' },
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 32, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 34, text: 'Coze', value: 34, color: 'blue' },
{ key: 35, text: 'Cohere', value: 35, color: 'blue' },
{ key: 36, text: 'DeepSeek', value: 36, color: 'black' },
{ key: 37, text: 'Cloudflare', value: 37, color: 'orange' },
{ key: 38, text: 'DeepL', value: 38, color: 'black' },
{ key: 39, text: 'together.ai', value: 39, color: 'blue' },
{ key: 42, text: 'VertexAI', value: 42, color: 'blue' },
{ key: 43, text: 'Proxy', value: 43, color: 'blue' },
{ key: 44, text: 'SiliconFlow', value: 44, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },
{ key: 20, text: '代理OpenRouter', value: 20, color: 'black' },
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
{ key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
{ key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' },
{ key: 4, text: '代理:CloseAI', value: 4, color: 'teal' },
{ key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
{ key: 9, text: '代理AI.LS', value: 9, color: 'yellow' },
{ key: 12, text: '代理API2GPT', value: 12, color: 'blue' },
{ key: 13, text: '代理AIGC2D', value: 13, color: 'purple' }
];

View File

@@ -170,7 +170,7 @@ const EditChannel = () => {
showInfo('请填写渠道名称和渠道密钥!');
return;
}
if (inputs.models.length === 0) {
if (inputs.type !== 43 && inputs.models.length === 0) {
showInfo('请至少选择一个模型!');
return;
}
@@ -370,63 +370,75 @@ const EditChannel = () => {
</Message>
)
}
<Form.Field>
<Form.Dropdown
label='模型'
placeholder={'请选择该渠道所支持的模型'}
name='models'
required
fluid
multiple
search
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels });
}}>填入相关模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: [] });
}}>清除所有模型</Button>
<Input
action={
<Button type={'button'} onClick={addCustomModel}>填入</Button>
}
placeholder='输入自定义模型名称'
value={customModel}
onChange={(e, { value }) => {
setCustomModel(value);
}}
onKeyDown={(e) => {
if (e.key === 'Enter') {
addCustomModel();
e.preventDefault();
}
}}
/>
</div>
<Form.Field>
<Form.TextArea
label='模型重定向'
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping'
onChange={handleInputChange}
value={inputs.model_mapping}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
{
inputs.type !== 43 && (
<Form.Field>
<Form.Dropdown
label='模型'
placeholder={'请选择该渠道所支持的模型'}
name='models'
required
fluid
multiple
search
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
)
}
{
inputs.type !== 43 && (
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels });
}}>填入相关模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: [] });
}}>清除所有模型</Button>
<Input
action={
<Button type={'button'} onClick={addCustomModel}>填入</Button>
}
placeholder='输入自定义模型名称'
value={customModel}
onChange={(e, { value }) => {
setCustomModel(value);
}}
onKeyDown={(e) => {
if (e.key === 'Enter') {
addCustomModel();
e.preventDefault();
}
}}
/>
</div>
)
}
{
inputs.type !== 43 && (
<Form.Field>
<Form.TextArea
label='模型重定向'
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping'
onChange={handleInputChange}
value={inputs.model_mapping}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
)
}
{
inputs.type === 33 && (
<Form.Field>