Compare commits

..

35 Commits

Author SHA1 Message Date
ckt1031
2c62053933 feat: better translation 2023-07-28 15:31:31 +08:00
ckt1031
442673a4be feat(ci): add english binary file 2023-07-28 14:59:34 +08:00
ckt1031
bb17e09818 Merge branch 'support-google-oauth' into refactor-main 2023-07-28 14:57:24 +08:00
ckt1031
80288492f9 Merge branch 'support-discord-oauth' into refactor-main 2023-07-28 14:57:11 +08:00
ckt1031
d89e13dac3 fix: use server address as redirect url 2023-07-28 14:56:17 +08:00
ckt1031
0c70181f81 fix: use server address as redirect url 2023-07-28 14:55:40 +08:00
ckt
3893d4bdf1 Merge branch 'songquanpeng:main' into refactor-main 2023-07-28 00:31:21 +08:00
ckt1031
4aa9f97017 fix: bump pgx 2023-07-27 14:55:45 +08:00
ckt1031
ab041d94e2 fix: go deps 2023-07-27 14:55:14 +08:00
ckt1031
18655fed08 Merge branch 'support-postgres-sql' into refactor-main 2023-07-27 14:51:23 +08:00
ckt1031
bba49c959e feat: support postgres 2023-07-27 14:47:45 +08:00
ckt1031
22f4419b85 fix: reject for wrong status code 2023-07-27 11:10:32 +08:00
ckt1031
c7c3b9d326 fix: backstick for postgres 2023-07-27 11:07:08 +08:00
ckt1031
0d6163a9fb feat: support PG-SQL 2023-07-27 00:10:08 +08:00
ckt1031
7419ca511e chore: use own docker image 2023-07-26 11:45:31 +08:00
ckt
ccd7a99b68 Merge branch 'songquanpeng:main' into refactor-main 2023-07-25 22:38:59 +08:00
ckt1031
759423d69e Merge branch 'support-discord-oauth' into refactor-main 2023-07-25 12:22:12 +08:00
ckt1031
04110d4b01 Merge branch 'support-google-oauth' into refactor-main 2023-07-25 12:22:04 +08:00
ckt1031
d6b2131720 fix: discord 2023-07-25 12:21:37 +08:00
ckt1031
438daea433 fix: google id 2023-07-25 12:14:41 +08:00
ckt1031
c6c070b8bd Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 22:05:54 +08:00
ckt1031
13b3bfee2a fix: channel issue 2023-07-24 22:05:21 +08:00
ckt1031
2b42b4f364 feat: support chatgpt next web 2023-07-24 21:49:04 +08:00
ckt1031
f37e41eb1d Merge branch 'support-google-oauth' into refactor-main 2023-07-24 21:31:45 +08:00
ckt1031
c144c64fff feat: support Google OAuth 2023-07-24 20:09:52 +08:00
ckt1031
8956e2fd60 Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 18:12:53 +08:00
ckt1031
30187cebe8 fix: use int instead of bool 2023-07-24 18:12:16 +08:00
ckt1031
00d3a78bef Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 15:35:16 +08:00
ckt1031
a588241515 feat: allow toggling stream mode of channels 2023-07-24 15:30:08 +08:00
ckt1031
546f9e1db5 Merge branch 'support-discord-oauth' into refactor-main 2023-07-24 12:22:14 +08:00
ckt1031
4908a9eddc feat: add Discord OAuth 2023-07-24 12:20:52 +08:00
ckt1031
15cdaee762 feat: bump go deps 2023-07-24 11:38:09 +08:00
ckt1031
395ee121ed feat: bump all node deps 2023-07-24 11:27:46 +08:00
ckt1031
4cea6279ab fix: install @babel/plugin-proposal-private-property-in-object 2023-07-24 11:19:55 +08:00
ckt1031
f50683e75f feat: add turnstile in login form 2023-07-24 10:54:04 +08:00
63 changed files with 1656 additions and 1769 deletions

View File

@@ -38,7 +38,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api-en
ckt1031/one-api-en
- name: Build and push Docker images
uses: docker/build-push-action@v3

View File

@@ -42,7 +42,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ckt1031/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ckt1031/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

59
.github/workflows/linux-release-en.yml vendored Normal file
View File

@@ -0,0 +1,59 @@
name: Linux Release (English)
permissions:
contents: write
on:
push:
tags:
- "*"
- "!*-alpha*"
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- uses: actions/setup-node@v3
with:
node-version: 16
- name: Build Frontend
env:
CI: ""
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ">=1.18.0"
- name: Build Backend (amd64)
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-en
- name: Build Backend (arm64)
run: |
sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64-en
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: |
one-api-en
one-api-arm64-en
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -4,7 +4,7 @@ WORKDIR /build
COPY ./web .
COPY ./VERSION .
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
RUN REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2

View File

@@ -57,13 +57,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
## Features
1. Support for multiple large models:
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
+ [x] [Anthropic Claude Series Models](https://anthropic.com)
+ [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
1. Supports multiple API access channels:
+ [x] Official OpenAI channel (support proxy configuration)
+ [x] **Azure OpenAI API**
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
+ [x] Custom channel: Various third-party proxy services not included in the list
2. Supports access to multiple channels through **load balancing**.
3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
@@ -137,7 +139,7 @@ The initial account username is `root` and password is `123456`.
cd one-api/web
npm install
npm run build
# Build the backend
cd ..
go mod download
@@ -173,12 +175,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
<summary><strong>Deploy on Sealos</strong></summary>
<div>
> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users.
> Click the button below to deploy with one click.👇
[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md).
</div>
</details>
@@ -190,7 +187,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
1. First, fork the code.
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
2. Go to [Zeabur](https://zeabur.com/), log in, and enter the console.
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
4. Copy the connection parameters and run ```create database `one-api` ``` to create the database.
5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository.
@@ -283,7 +280,7 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct.
## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM
[FastGPT](https://github.com/c121914yu/FastGPT): Build an AI knowledge base in three minutes
## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.

View File

@@ -63,10 +63,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
2. 支持配置镜像以及众多第三方代理服务:
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
@@ -94,7 +93,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
19. 支持通过系统访问令牌访问管理 API。
20. 支持 Cloudflare Turnstile 用户校验。
21. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
@@ -102,16 +101,16 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
### 基于 Docker 进行部署
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
Nginx 的参考配置:
```
server{
@@ -153,7 +152,7 @@ sudo service nginx restart
cd one-api/web
npm install
npm run build
# 构建后端
cd ..
go mod download
@@ -211,11 +210,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
<summary><strong>部署到 Sealos </strong></summary>
<div>
> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩
> Sealos 可视化部署,仅需 1 分钟
点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟):
[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。
</div>
</details>
@@ -227,7 +224,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
2. 进入 [Zeabur](https://zeabur.com/),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。
5. 然后在 Service -> Add Service选择 Git第一次使用需要先授权选择你 fork 的仓库。
@@ -276,18 +273,11 @@ graph LR
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL
+ 例子:
+ MySQL`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ PostgreSQL`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 8.0 版本
+ 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。
+ 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
+ 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。
+ 请根据你的数据库配置修改下列参数(或者保持默认值):
+ `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。
@@ -323,7 +313,6 @@ https://openai.justsong.cn
+ 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
+ 其中补全倍率对于 GPT3.5 固定为 1.33GPT4 为 2与官方保持一致。
+ 如果是非流模式,官方接口会返回消耗的总 token但是你要注意提示和补全的消耗倍率不一样。
+ 注意One API 的默认倍率就是官方倍率,是已经调整过的。
2. 账户额度足够为什么提示额度不足?
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。
@@ -340,8 +329,7 @@ https://openai.justsong.cn
+ 上游通道 429 了。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
[FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库
## 注意

View File

@@ -22,6 +22,7 @@ var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
var UsingPostgreSQL = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
@@ -38,25 +39,12 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var DiscordOAuthEnabled = false
var WeChatAuthEnabled = false
var GoogleOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
@@ -68,10 +56,16 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var DiscordClientId = ""
var DiscordClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var GoogleClientId = ""
var GoogleClientSecret = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
@@ -153,6 +147,16 @@ const (
ChannelStatusDisabled = 2 // also don't use 0
)
const (
ChannelAllowNonStreamEnabled = 1
ChannelAllowNonStreamDisabled = 2
)
const (
ChannelAllowStreamEnabled = 1
ChannelAllowStreamDisabled = 2
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
@@ -171,28 +175,24 @@ const (
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
}

View File

@@ -42,14 +42,10 @@ var ModelRatio = map[string]float64{
"claude-2": 30,
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
"qwen-plus-v1": 0.5715, // Same as above
"SparkDesk": 0.8572, // TBD
}
func ModelRatio2JSONString() string {

View File

@@ -7,7 +7,6 @@ import (
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
@@ -178,15 +177,3 @@ func Max(a int, b int) int {
return b
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}

View File

@@ -1,17 +1,20 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
@@ -23,8 +26,6 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
@@ -60,21 +61,86 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
return err, nil
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err, nil
if resp.StatusCode != http.StatusOK {
return errors.New(fmt.Sprintf("status code %d", resp.StatusCode)), nil
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if channel.AllowStreaming == common.ChannelAllowStreamEnabled && isStream {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
// ChatGPT Next Web
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
return err, nil
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
}
}
if responseText == "" {
return errors.New("Empty response"), nil
}
} else if channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled {
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
}
return nil, nil
}
func buildTestRequest() *ChatRequest {
func buildTestRequest(stream bool) *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
Stream: stream,
}
testMessage := Message{
Role: "user",
@@ -101,7 +167,7 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
tok := time.Now()
@@ -156,7 +222,6 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
@@ -167,6 +232,7 @@ func testAllChannels(notify bool) error {
continue
}
tik := time.Now()
testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()

View File

@@ -1,12 +1,13 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func GetAllChannels(c *gin.Context) {

223
controller/discord.go Normal file
View File

@@ -0,0 +1,223 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type DiscordUser struct {
Id string `json:"id"`
Username string `json:"username"`
}
func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) {
if codeFromURLParamaters == "" {
return nil, errors.New("无效参数")
}
RequestClient := &http.Client{}
accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf(
"client_id=%s&client_secret=%s&grant_type=authorization_code&redirect_uri=https://%s/oauth/discord&code=%s&scope=identify",
common.DiscordClientId, common.DiscordClientSecret, common.ServerAddress, codeFromURLParamaters,
)))
req, _ := http.NewRequest("POST",
"https://discordapp.com/api/oauth2/token",
accessTokenBody,
)
req.Header = http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded"},
"Accept": []string{"application/json"},
}
resp, err := RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("访问令牌无效")
}
var discordOAuthResponse DiscordOAuthResponse
json.NewDecoder(resp.Body).Decode(&discordOAuthResponse)
accessToken := fmt.Sprintf("Bearer %s", discordOAuthResponse.AccessToken)
// Get User Info
req, _ = http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
req.Header = http.Header{
"Content-Type": []string{"application/json"},
"Authorization": []string{accessToken},
}
defer resp.Body.Close()
resp, err = RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("Discord 用户信息无效")
}
var discordUser DiscordUser
json.NewDecoder(resp.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.Id == "" {
return nil, errors.New("返回值无效,用户字段为空,请稍后再试!")
}
defer resp.Body.Close()
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
if discordUser.Username != "" {
user.DisplayName = discordUser.Username
} else {
user.DisplayName = "Discord User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.DiscordId = discordUser.Id
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

226
controller/google.go Normal file
View File

@@ -0,0 +1,226 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GoogleAccessTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}
type GoogleUser struct {
Sub string `json:"sub"`
Name string `json:"name"`
}
func getGoogleUserInfoByCode(codeFromURLParamaters string, host string) (*GoogleUser, error) {
if codeFromURLParamaters == "" {
return nil, errors.New("无效参数")
}
RequestClient := &http.Client{}
accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf(
"code=%s&client_id=%s&client_secret=%s&redirect_uri=%s/oauth/google&grant_type=authorization_code",
codeFromURLParamaters, common.GoogleClientId, common.GoogleClientSecret, common.ServerAddress,
)))
req, _ := http.NewRequest("POST",
"https://oauth2.googleapis.com/token",
accessTokenBody,
)
req.Header = http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded"},
"Accept": []string{"application/json"},
}
resp, err := RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("访问令牌无效")
}
var googleTokenResponse GoogleAccessTokenResponse
json.NewDecoder(resp.Body).Decode(&googleTokenResponse)
accessToken := "Bearer " + googleTokenResponse.AccessToken
// Get User Info
req, _ = http.NewRequest("GET", "https://www.googleapis.com/oauth2/v3/userinfo", nil)
req.Header = http.Header{
"Content-Type": []string{"application/json"},
"Authorization": []string{accessToken},
}
defer resp.Body.Close()
resp, err = RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("Google 用户信息无效")
}
var googleUser GoogleUser
// Parse json to googleUser
err = json.NewDecoder(resp.Body).Decode(&googleUser)
if err != nil {
return nil, err
}
if googleUser.Sub == "" {
return nil, errors.New("返回值无效,用户字段为空,请稍后再试!")
}
defer resp.Body.Close()
return &googleUser, nil
}
func GoogleOAuth(c *gin.Context) {
session := sessions.Default(c)
username := session.Get("username")
if username != nil {
GoogleBind(c)
return
}
if !common.GoogleOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Google 登录以及注册",
})
return
}
code := c.Query("code")
googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
GoogleId: googleUser.Sub,
}
if model.IsGoogleIdAlreadyTaken(user.GoogleId) {
err := user.FillUserByGoogleId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "google_" + strconv.Itoa(model.GetMaxUserId()+1)
if googleUser.Name != "" {
user.DisplayName = googleUser.Name
} else {
user.DisplayName = "Google User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func GoogleBind(c *gin.Context) {
if !common.GoogleOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Google 登录以及注册",
})
return
}
code := c.Query("code")
googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
GoogleId: googleUser.Sub,
}
if model.IsGoogleIdAlreadyTaken(user.GoogleId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Google 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.GoogleId = googleUser.Sub
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
@@ -21,6 +20,10 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"discord_oauth": common.DiscordOAuthEnabled,
"discord_client_id": common.DiscordClientId,
"google_oauth": common.GoogleOAuthEnabled,
"google_client_id": common.GoogleClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
@@ -80,22 +83,6 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
if common.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
}
}
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
})
return
}
}
if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -288,15 +288,6 @@ func init() {
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
@@ -333,33 +324,6 @@ func init() {
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-v1",
Parent: nil,
},
{
Id: "qwen-plus-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {

View File

@@ -50,11 +50,11 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
case "DiscordOAuthEnabled":
if option.Value == "true" && common.DiscordClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名",
"message": "无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret",
})
return
}
@@ -66,6 +66,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "GoogleOAuthEnabled":
if option.Value == "true" && common.GoogleClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Google OAuth请先填入 Google Client ID 以及 Google Client Secret",
})
return
}
case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{

View File

@@ -1,239 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
usage.PromptTokens += aliResponse.Usage.InputTokens
usage.CompletionTokens += aliResponse.Usage.OutputTokens
usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -3,22 +3,22 @@ package controller
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
SessionKey string `json:"session_key"`
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
}
type BaiduMessage struct {
@@ -54,35 +54,6 @@ type BaiduChatStreamResponse struct {
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
@@ -130,9 +101,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
}
choice.FinishReason = "stop"
response := ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
@@ -143,40 +112,6 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
return &response
}
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
baiduEmbeddingRequest := BaiduEmbeddingRequest{
Input: nil,
}
switch request.Input.(type) {
case string:
baiduEmbeddingRequest.Input = []string{request.Input.(string)}
case []any:
for _, item := range request.Input.([]any) {
if str, ok := item.(string); ok {
baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
}
}
}
return &baiduEmbeddingRequest
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
@@ -205,7 +140,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -273,96 +212,3 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func getBaiduAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
parts[0], parts[1]), nil)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

View File

@@ -81,10 +81,7 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
@@ -141,7 +138,11 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -4,11 +4,12 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
@@ -34,7 +35,23 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
// ChatGPT Next Web
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
}
if !strings.HasPrefix(data, "data:") {
continue
}
dataChan <- data
@@ -46,7 +63,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
return
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
@@ -56,7 +73,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
return
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
@@ -66,7 +83,11 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -88,7 +109,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText
}
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
@@ -128,17 +149,5 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
return nil, &textResponse.Usage
}

View File

@@ -94,7 +94,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsS
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &stopFinishReason
choice.FinishReason = "stop"
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
@@ -143,7 +143,11 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(c)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -5,13 +5,13 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
@@ -20,18 +20,12 @@ const (
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
httpClient = &http.Client{}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -79,7 +73,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
@@ -100,10 +94,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -145,16 +135,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
@@ -169,8 +153,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
}
var promptTokens int
var completionTokens int
@@ -224,20 +206,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonStr, err := json.Marshal(baiduRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
requestBody = bytes.NewBuffer(jsonStr)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
@@ -252,116 +226,102 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err := json.Marshal(aliRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return errorWrapper(nil, "bad_status_code", resp.StatusCode)
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id")
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
var streamResponseText string
defer func() {
// c.Writer.Flush()
go func() {
if consumeQuota {
quota := 0
completionRatio := 1.0
if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
completionRatio = 1.333333
}
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if consumeQuota {
quota := 0
completionRatio := 1.0
if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
completionRatio = 1.333333
}
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else {
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
if apiType == APITypeZhipu {
// zhipu's API does not return prompt tokens & completion tokens
promptTokens = textResponse.Usage.TotalTokens
}
}
}()
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
switch apiType {
case APITypeOpenAI:
@@ -370,11 +330,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
streamResponseText = responseText
return nil
} else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
err, usage := openaiHandler(c, resp, consumeQuota)
if err != nil {
return err
}
@@ -389,8 +348,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
streamResponseText = responseText
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
@@ -413,14 +371,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
err, usage := baiduHandler(c, resp)
if err != nil {
return err
}
@@ -435,8 +386,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
streamResponseText = responseText
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
@@ -457,8 +407,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
@@ -468,49 +416,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aliHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
if isStream {
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}

View File

@@ -2,13 +2,10 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"one-api/common"
)
var stopFinishReason = "stop"
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
@@ -107,11 +104,3 @@ func shouldDisableChannel(err *OpenAIError) bool {
}
return false
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -1,277 +0,0 @@
package controller
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"net/url"
"one-api/common"
"strings"
"time"
)
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
})
}
}
xunfeiRequest := XunfeiChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
xunfeiRequest.Parameter.Chat.Domain = "general"
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
}
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &stopFinishReason
}
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
HmacWithShaToBase64 := func(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
ul, err := url.Parse(hostUrl)
if err != nil {
fmt.Println(err)
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
sign := strings.Join(signString, "\n")
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
callUrl := hostUrl + "?" + v.Encode()
return callUrl
}
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
hostUrl := "wss://aichat.xf-yun.com/v1/chat"
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
}
data := requestOpenAI2Xunfei(textRequest, appId)
err = conn.WriteJSON(data)
if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var xunfeiResponse XunfeiChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &xunfeiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if xunfeiResponse.Header.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -163,6 +163,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
choice.FinishReason = ""
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
@@ -175,7 +176,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResp
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = &stopFinishReason
choice.FinishReason = "stop"
response := ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
@@ -193,8 +194,8 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
@@ -207,24 +208,23 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
go func() {
for scanner.Scan() {
data := scanner.Text()
lines := strings.Split(data, "\n")
for i, line := range lines {
if len(line) < 5 {
continue
}
if line[:5] == "data:" {
dataChan <- line[5:]
if i != len(lines)-1 {
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
data = strings.Trim(data, "\"")
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] == "data:" {
dataChan <- data[5:]
} else if data[:5] == "meta:" {
metaChan <- data[5:]
}
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:

View File

@@ -46,6 +46,7 @@ type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type TextRequest struct {
@@ -81,9 +82,8 @@ type OpenAIErrorWithStatusCode struct {
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
@@ -100,19 +100,6 @@ type OpenAITextResponse struct {
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
@@ -124,7 +111,7 @@ type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
FinishReason string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
@@ -176,7 +163,7 @@ func Relay(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
@@ -207,10 +194,10 @@ func RelayNotImplemented(c *gin.Context) {
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Message: fmt.Sprintf("API not found: %s:%s", c.Request.Method, c.Request.URL.Path),
Type: "one_api_error",
Param: "",
Code: "",
Code: "api_not_found",
}
c.JSON(http.StatusNotFound, gin.H{
"error": err,

View File

@@ -109,10 +109,10 @@ func AddToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
if len(token.Name) == 0 || len(token.Name) > 20 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称长",
"message": "令牌名称长度必须在1-20之间",
})
return
}
@@ -171,13 +171,6 @@ func UpdateToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
})
return
}
cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -2,7 +2,7 @@ version: '3.4'
services:
one-api:
image: justsong/one-api:latest
image: ckt1031/one-api:latest
container_name: one-api
restart: always
command: --log-dir /app/logs

42
go.mod
View File

@@ -9,21 +9,26 @@ require (
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0
github.com/go-playground/validator/v10 v10.14.1
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.9.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
golang.org/x/crypto v0.11.0
gorm.io/driver/mysql v1.5.1
gorm.io/driver/sqlite v1.5.2
gorm.io/gorm v1.25.2
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.2 // indirect
)
require (
github.com/bytedance/sonic v1.9.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
@@ -31,31 +36,28 @@ require (
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
golang.org/x/arch v0.4.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.5.2 // indirect
gorm.io/driver/postgres v1.5.2
)

77
go.sum
View File

@@ -1,8 +1,8 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/bytedance/sonic v1.9.2 h1:GDaNjuWSGu09guE9Oql0MSTNhNCLlWwO8y/xM5BzcbM=
github.com/bytedance/sonic v1.9.2/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -43,12 +43,13 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k=
github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -67,25 +68,24 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.4.2 h1:u1gmGDwbdRUZiwisBm/Ky2M14uQyUP65bG8+20nnyrg=
github.com/jackc/pgx/v5 v5.4.2/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -102,7 +102,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -115,8 +114,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
@@ -136,8 +135,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -147,36 +146,36 @@ github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
@@ -191,15 +190,13 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc=
gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -3,11 +3,6 @@
"%d 点额度": "%d point quota",
"尚未实现": "Not yet implemented",
"余额不足": "Insufficient balance",
"危险操作": "Hazardous operations",
"输入你的账户名": "Enter your account name",
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"通道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"通道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
@@ -39,8 +34,8 @@
"兑换码个数必须大于0": "The number of redemption codes must be greater than 0",
"一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100",
"通过令牌「%s」使用模型 %s 消耗 %s模型倍率 %.2f,分组倍率 %.2f": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)",
"当前分组上游负载已饱和,请稍后再试": "The current group load is saturated, please try again later",
"令牌名称过长": "Token name is too long",
"当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.",
"令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20",
"令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.",
"令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota",
"管理员关闭了密码登录": "The administrator has turned off password login",
@@ -229,7 +224,7 @@
"已是最新版本": "Is the latest version",
"检查更新": "Check for updates",
"公告": "Announcement",
"在此输入新的公告内容,支持 Markdown & HTML 代码": "Enter the new announcement content here, supports Markdown & HTML code",
"在此输入新的公告内容": "Enter new announcement content here",
"保存公告": "Save Announcement",
"个性化设置": "Personalization Settings",
"系统名称": "System Name",
@@ -265,7 +260,7 @@
"注意": "Note",
"此处生成的令牌用于系统管理": "The token generated here is used for system management",
"而非用于请求 OpenAI 相关的服务": "Not for requesting OpenAI related services",
"请知悉": "Please be aware",
"请知悉": "Please be aware.",
"更新个人信息": "Update Personal Information",
"生成系统访问令牌": "Generate System Access Token",
"复制邀请链接": "Copy Invitation Link",
@@ -289,7 +284,7 @@
"兑换时间": "Redemption Time",
"尚未兑换": "Not yet redeemed",
"已复制到剪贴板": "Copied to clipboard",
"无法复制到剪贴板": "Unable to copy to clipboard",
"无法复制到剪贴板": "Unable to copy to clipboard.",
"请手动复制": "Please copy manually",
"已将兑换码填入搜索框": "The voucher code has been filled into the search box",
"复制": "Copy",
@@ -432,7 +427,7 @@
"一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota",
"设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information",
"请输入充值码!": "Please enter the recharge code!",
"请输入名称": "Please enter a name",
@@ -498,7 +493,6 @@
"参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota",
"取消": "Cancel",
"请输入新的剩余额度": "Please enter the new remaining quota",
"请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code",
"请输入用户名": "Please enter username",
@@ -510,15 +504,72 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为通道命名": "Please name the channel",
"兑换中": "Redeeming",
"请求失败": "Request failed",
"百度文心千帆": "Baidu Wenxin Qianfan",
"智谱": "Zhipuai",
"代理:": "Agent: ",
"请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
"账户已删除!": "Account deleted!",
"删除个人账户": "Delete personal account",
"重新发送": "Resend ",
"确认删除自己的帐户": "Confirm deletion of your own account",
"输入你的账户名": "Enter your account name",
"以确认删除": "to confirm deletion",
"重试": "Retry",
"无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually, the redemption code has been filled in the search box.",
"密码重置完成": "Password reset completed",
"流式请求和非流式请求不能同时禁用!": "Streaming requests and non-streaming requests cannot be disabled at the same time!",
"请为渠道命名": "Please name the channel",
"请选择可以使用该渠道的分组": "Please select the group that can use this channel",
"请选择该渠道所支持的模型": "Please select the models supported by this channel",
"填入": "Fill in",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"允许流式请求": "Allow streaming requests",
"允许非流式请求": "Allow non-streaming requests",
"请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次": "Please enter the access token, the current version does not support automatic refresh, please update it every 30 days",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
"注意,": "Note that, ",
",图片演示。": "related image demo.",
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:"
"此项可选用于通过Mirror站来进行 API 调用请EnterMirror站地址格式为https://domain.com": "This is optional, used to make API calls through the Mirror site, please enter the Mirror site address, the format is: https://domain.com",
"新密码": "New password",
"新密码已复制到剪贴板:": "New password copied to clipboard: ",
"兑换失败,": "Redemption failed, ",
"当前分组 %s 下对于模型 %s 无可用渠道": "There are no available channels for model %s under the current group %s",
"无权将其他用户权限等级提升到大于等于自己的权限等级": "You are not allowed to raise the permission level of other users to greater than or equal to your own permission level",
"不能删除超级管理员账户": "Cannot delete super administrator account",
"无效参数": "Invalid parameter",
"访问令牌无效": "Access token invalid",
"Google 用户信息无效": "Google user information invalid",
"返回值无效,用户字段为空,请稍后再试!": "The return value is invalid, the user field is empty, please try again later!",
"管理员未开启通过 Google 登录以及注册": "The administrator has not enabled login and registration via Google",
"该 Google 账户已被绑定": "The Google account has been bound",
"点击 <a href='%s'>此处</a> 进行密码重置。": "Click <a href='%s'>here</a> to reset your password.",
"如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s ": "If the link cannot be clicked, please try to click the link below or copy it to the browser to open: <br> %s ",
"该渠道类型当前版本不支持测试,请手动测试": "The current version of the channel type does not support testing, please test manually",
"无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret": "Unable to enable Discord OAuth, please fill in the Discord Client ID and Discord Client Secret first!",
"无法启用 Google OAuth请先填入 Google Client ID 以及 Google Client Secret": "Unable to enable Google OAuth, please fill in the Google Client ID and Google Client Secret first!",
"管理员未开启通过 Discord 登录以及注册": "The administrator has not enabled login and registration via Discord",
"该 Discord 账户已被绑定": "The Discord account has been bound",
"Discord 用户信息无效": "Discord user information invalid",
"绑定 Discord 账号": "Bind Discord account",
"绑定 Google 账号": "Bind Google account",
"已绑定的 Discord 账户": "Bound Discord account",
"已绑定的 Google 账户": "Bound Google account",
"身份验证": "Authentication",
"允许通过 Discord 账户登录和注册": "Allow login and registration via Discord account",
"允许通过 Google 账户登录和注册": "Allow login and registration via Google account",
"配置 Discord OAuth 应用程序": "Configure Discord OAuth application",
"配置 Google OAuth 应用程序": "Configure Google OAuth application",
"用以支持通过 Discord 进行登录注册,": "Used to support login and registration via Discord, ",
"用以支持通过 Google 进行登录注册,": "Used to support login and registration via Google, ",
"管理你的": "Manage your ",
"客户 ID": "Client ID",
"客户秘密": "Client Secret",
"失败重试次数": "Retry times",
"保存": "Save",
"输入您注册的 Discord OAuth APP 的 ID": "Enter the ID of your registered Discord OAuth APP",
"输入您注册的 Google OAuth APP 的 ID": "Enter the ID of your registered Google OAuth APP"
}

View File

@@ -26,9 +26,6 @@ func main() {
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {

View File

@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
@@ -12,7 +13,8 @@ import (
)
type ModelRequest struct {
Model string `json:"model"`
Model string `json:"model"`
Stream bool `json:"stream" default:"true"`
}
func Distribute() func(c *gin.Context) {
@@ -84,7 +86,8 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
log.Print(modelRequest.Stream)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {

View File

@@ -1,24 +1,47 @@
package model
import (
"fmt"
"one-api/common"
"strings"
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
cmd := "`group` = ? and model = ? and enabled = 1"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
cmd = "\"group\" = ? and model = ? and enabled = true"
}
if stream {
cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled)
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
}
if common.UsingSQLite {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
@@ -36,10 +59,12 @@ func (channel *Channel) AddAbilities() error {
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
AllowStreaming: channel.AllowStreaming,
AllowNonStreaming: channel.AllowNonStreaming,
}
abilities = append(abilities, ability)
}

View File

@@ -21,13 +21,19 @@ var (
func CacheGetTokenByKey(key string) (*Token, error) {
var token Token
whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(whereItem, key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(whereItem, key).First(&token).Error
if err != nil {
return nil, err
}
@@ -160,9 +166,9 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
if !common.RedisEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, stream)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -170,6 +176,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
idx := rand.Intn(len(channels))
return channels[idx], nil
var filteredChannels []*Channel
for _, channel := range channels {
if (stream && channel.AllowStreaming == common.ChannelAllowStreamEnabled) || (!stream && channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled) {
filteredChannels = append(filteredChannels, channel)
}
}
idx := rand.Intn(len(filteredChannels))
return filteredChannels[idx], nil
}

View File

@@ -1,8 +1,9 @@
package model
import (
"gorm.io/gorm"
"one-api/common"
"gorm.io/gorm"
)
type Channel struct {
@@ -23,6 +24,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -37,7 +40,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
whereItem := "id = ? or name LIKE ? or `key` = ?"
if common.UsingPostgreSQL {
whereItem = "id = ? or name LIKE ? or \"key\" = ?"
}
err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@@ -55,7 +64,9 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
if common.UsingPostgreSQL {
err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error

View File

@@ -1,14 +1,13 @@
package model
import (
"one-api/common"
"os"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"os"
"strings"
"time"
)
var DB *gorm.DB
@@ -36,52 +35,41 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL
common.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
func CountTable(tableName string) (num int64) {
DB.Table(tableName).Count(&num)
return
}
func InitDB() (err error) {
db, err := chooseDB()
var db *gorm.DB
if os.Getenv("POSTGRES_DSN") != "" {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else if os.Getenv("SQL_DSN") != "" {
// Use MySQL
common.SysLog("using MySQL as database")
db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else {
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
common.SysLog("database connected")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
DB = db
sqlDB, err := DB.DB()
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
err = db.AutoMigrate(&Channel{})
err := db.AutoMigrate(&Channel{})
if err != nil {
return err
}

View File

@@ -30,7 +30,9 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["GoogleOAuthEnabled"] = strconv.FormatBool(common.GoogleOAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
@@ -39,8 +41,6 @@ func InitOptionMap() {
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
@@ -55,9 +55,13 @@ func InitOptionMap() {
common.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["DiscordClientId"] = ""
common.OptionMap["DiscordClientSecret"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["GoogleClientId"] = ""
common.OptionMap["GoogleClientSecret"] = ""
common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
@@ -137,14 +141,16 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "DiscordOAuthEnabled":
common.DiscordOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "GoogleOAuthEnabled":
common.GoogleOAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
@@ -158,8 +164,6 @@ func updateOptionMap(key string, value string) (err error) {
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
case "SMTPPort":
@@ -177,6 +181,10 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "DiscordClientId":
common.DiscordClientId = value
case "DiscordClientSecret":
common.DiscordClientSecret = value
case "Footer":
common.Footer = value
case "SystemName":
@@ -189,6 +197,10 @@ func updateOptionMap(key string, value string) (err error) {
common.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
case "GoogleClientId":
common.GoogleClientId = value
case "GoogleClientSecret":
common.GoogleClientSecret = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
case "TurnstileSecretKey":

View File

@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"gorm.io/gorm"
)
type Redemption struct {
@@ -51,7 +52,14 @@ func Redeem(key string, userId int) (quota int, err error) {
redemption := &Redemption{}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@@ -3,9 +3,10 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strings"
"gorm.io/gorm"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
@@ -19,7 +20,9 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
GoogleId string `json:"google_id" gorm:"column:google_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
@@ -169,6 +172,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("Discord id 为空!")
}
DB.Where(User{DiscordId: user.DiscordId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -177,6 +188,14 @@ func (user *User) FillUserByWeChatId() error {
return nil
}
func (user *User) FillUserByGoogleId() error {
if user.GoogleId == "" {
return errors.New("Google id 为空!")
}
DB.Where(User{GoogleId: user.GoogleId}).First(user)
return nil
}
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
@@ -193,6 +212,14 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
}
func IsDiscordIdAlreadyTaken(discordId string) bool {
return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
}
func IsGoogleIdAlreadyTaken(googleId string) bool {
return DB.Where("google_id = ?", googleId).Find(&User{}).RowsAffected == 1
}
func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
@@ -267,7 +294,13 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
selectItem := "`group`"
if common.UsingPostgreSQL {
selectItem = "\"group\""
}
err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error
return group, err
}

View File

@@ -21,14 +21,16 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/google", middleware.CriticalRateLimit(), controller.GoogleOAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
userRoute.GET("/logout", controller.Logout)
selfRoute := userRoute.Group("/")
@@ -36,7 +38,7 @@ func SetApiRouter(router *gin.Engine) {
{
selfRoute.GET("/self", controller.GetSelf)
selfRoute.PUT("/self", controller.UpdateSelf)
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.DELETE("/self", middleware.TurnstileCheck(), controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)

View File

@@ -18,7 +18,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
router.Use(middleware.Cache())
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
router.NoRoute(func(c *gin.Context) {
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") {
if strings.HasPrefix(c.Request.RequestURI, "/v1") {
controller.RelayNotFound(c)
return
}

View File

@@ -3,18 +3,19 @@
"version": "0.1.0",
"private": true,
"dependencies": {
"axios": "^0.27.2",
"@babel/plugin-proposal-private-property-in-object": "^7.21.11",
"axios": "^1.4.0",
"history": "^5.3.0",
"marked": "^4.1.1",
"marked": "^5.1.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-router-dom": "^6.3.0",
"react-router-dom": "^6.14.2",
"react-scripts": "5.0.1",
"react-toastify": "^9.0.8",
"react-turnstile": "^1.0.5",
"react-toastify": "^9.1.3",
"react-turnstile": "^1.1.1",
"semantic-ui-css": "^2.5.0",
"semantic-ui-react": "^2.1.3"
"semantic-ui-react": "^2.1.4"
},
"scripts": {
"start": "react-scripts start",
@@ -41,7 +42,7 @@
]
},
"devDependencies": {
"prettier": "^2.7.1"
"prettier": "^3.0.0"
},
"prettier": {
"singleQuote": true,

View File

@@ -12,6 +12,8 @@ import AddUser from './pages/User/AddUser';
import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import DiscordOAuth from './components/DiscordOAuth';
import GoogleOAuth from './components/GoogleOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import { StatusContext } from './context/Status';
@@ -239,6 +241,24 @@ function App() {
</Suspense>
}
/>
<Route
HEAD
path='/oauth/discord'
element={
<Suspense fallback={<Loading></Loading>}>
<DiscordOAuth />
</Suspense>
}
/>
<Route
path='/oauth/google'
element={
<Suspense fallback={<Loading></Loading>}>
<GoogleOAuth />
support-google-oauth
</Suspense>
}
/>
<Route
path='/setting'
element={
@@ -252,11 +272,11 @@ function App() {
<Route
path='/topup'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
}
/>
<Route

View File

@@ -363,12 +363,9 @@ const ChannelsTable = () => {
</Table.Cell>
<Table.Cell>
<Popup
trigger={<span onClick={() => {
updateChannelBalance(channel.id, channel.name, idx);
}} style={{ cursor: 'pointer' }}>
{renderBalance(channel.type, channel.balance)}
</span>}
content='点击更新'
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
key={channel.id}
trigger={renderBalance(channel.type, channel.balance)}
basic
/>
</Table.Cell>
@@ -383,16 +380,16 @@ const ChannelsTable = () => {
>
测试
</Button>
{/*<Button*/}
{/* size={'small'}*/}
{/* positive*/}
{/* loading={updatingBalance}*/}
{/* onClick={() => {*/}
{/* updateChannelBalance(channel.id, channel.name, idx);*/}
{/* }}*/}
{/*>*/}
{/* 更新余额*/}
{/*</Button>*/}
<Button
size={'small'}
positive
loading={updatingBalance}
onClick={() => {
updateChannelBalance(channel.id, channel.name, idx);
}}
>
更新余额
</Button>
<Popup
trigger={
<Button size='small' negative>

View File

@@ -0,0 +1,57 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const DiscordOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, count) => {
const res = await API.get(`/api/oauth/discord?code=${code}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
sendCode(code, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default DiscordOAuth;

View File

@@ -0,0 +1,57 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const GoogleOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, count) => {
const res = await API.get(`/api/oauth/google?code=${code}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
sendCode(code, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default GoogleOAuth;

View File

@@ -2,7 +2,8 @@ import React, { useContext, useEffect, useState } from 'react';
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess } from '../helpers';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -14,6 +15,9 @@ const LoginForm = () => {
const [submitted, setSubmitted] = useState(false);
const { username, password } = inputs;
const [userState, userDispatch] = useContext(UserContext);
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
let navigate = useNavigate();
const [status, setStatus] = useState({});
const logo = getLogo();
@@ -26,17 +30,34 @@ const LoginForm = () => {
if (status) {
status = JSON.parse(status);
setStatus(status);
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(status.turnstile_site_key);
}
}
}, []);
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const openGoogleOAuth = () => {
window.open(
`https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=profile`
);
};
const onGitHubOAuthClicked = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const onDiscordOAuthClicked = () => {
window.open(
`https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`,
);
};
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
@@ -65,7 +86,12 @@ const LoginForm = () => {
async function handleSubmit(e) {
setSubmitted(true);
if (username && password) {
const res = await API.post(`/api/user/login`, {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, {
username,
password
});
@@ -108,6 +134,16 @@ const LoginForm = () => {
value={password}
onChange={handleChange}
/>
{turnstileEnabled ? (
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
) : (
<></>
)}
<Button color='green' fluid size='large' onClick={handleSubmit}>
登录
</Button>
@@ -123,28 +159,40 @@ const LoginForm = () => {
点击注册
</Link>
</Message>
{status.github_oauth || status.wechat_login ? (
{status.github_oauth || status.wechat_login || status.discord_oauth || status.google_oauth ? (
<>
<Divider horizontal>Or</Divider>
{status.github_oauth ? (
{status.discord_oauth && (
<Button
circular
color='blue'
icon='discord'
onClick={onDiscordOAuthClicked}
/>
)}
{status.github_oauth && (
<Button
circular
color='black'
icon='github'
onClick={onGitHubOAuthClicked}
/>
) : (
<></>
)}
{status.wechat_login ? (
{status.wechat_login && (
<Button
circular
color='green'
icon='wechat'
onClick={onWeChatLoginClicked}
/>
) : (
<></>
)}
{status.google_oauth && (
<Button
circular
color='red'
icon='google'
onClick={openGoogleOAuth}
/>
)}
</>
) : (

View File

@@ -112,7 +112,7 @@ const OtherSetting = () => {
<Form.Group widths='equal'>
<Form.TextArea
label='公告'
placeholder='在此输入新的公告内容,支持 Markdown & HTML 代码'
placeholder='在此输入新的公告内容'
value={inputs.Notice}
name='Notice'
onChange={handleInputChange}

View File

@@ -25,8 +25,6 @@ const PersonalSetting = () => {
const [loading, setLoading] = useState(false);
const [disableButton, setDisableButton] = useState(false);
const [countdown, setCountdown] = useState(30);
const [affLink, setAffLink] = useState("");
const [systemToken, setSystemToken] = useState("");
useEffect(() => {
let status = localStorage.getItem('status');
@@ -61,10 +59,8 @@ const PersonalSetting = () => {
const res = await API.get('/api/user/token');
const { success, message, data } = res.data;
if (success) {
setSystemToken(data);
setAffLink("");
await copy(data);
showSuccess(`令牌已重置并已复制到剪贴板`);
showSuccess(`令牌已重置并已复制到剪贴板${data}`);
} else {
showError(message);
}
@@ -75,27 +71,13 @@ const PersonalSetting = () => {
const { success, message, data } = res.data;
if (success) {
let link = `${window.location.origin}/register?aff=${data}`;
setAffLink(link);
setSystemToken("");
await copy(link);
showSuccess(`邀请链接已复制到剪切板`);
showNotice(`邀请链接已复制到剪切板${link}`);
} else {
showError(message);
}
};
const handleAffLinkClick = async (e) => {
e.target.select();
await copy(e.target.value);
showSuccess(`邀请链接已复制到剪切板`);
};
const handleSystemTokenClick = async (e) => {
e.target.select();
await copy(e.target.value);
showSuccess(`系统令牌已复制到剪切板`);
};
const deleteAccount = async () => {
if (inputs.self_account_deletion_confirmation !== userState.user.username) {
showError('请输入你的账户名以确认删除!');
@@ -130,12 +112,24 @@ const PersonalSetting = () => {
}
};
const openGoogleOAuth = () => {
window.open(
`https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=https://www.googleapis.com/auth/userinfo.profile`
);
};
const openGitHubOAuth = () => {
window.open(
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
);
};
const openDiscordOAuth = () => {
window.open(
`https://discord.com/api/oauth2/authorize?client_id=${status.discord_client_id}&response_type=code&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`,
);
};
const sendVerificationCode = async () => {
setDisableButton(true);
if (inputs.email === '') return;
@@ -186,25 +180,6 @@ const PersonalSetting = () => {
<Button onClick={() => {
setShowAccountDeleteModal(true);
}}>删除个人账户</Button>
{systemToken && (
<Form.Input
fluid
readOnly
value={systemToken}
onClick={handleSystemTokenClick}
style={{ marginTop: '10px' }}
/>
)}
{affLink && (
<Form.Input
fluid
readOnly
value={affLink}
onClick={handleAffLinkClick}
style={{ marginTop: '10px' }}
/>
)}
<Divider />
<Header as='h3'>账号绑定</Header>
{
@@ -252,6 +227,17 @@ const PersonalSetting = () => {
<Button onClick={openGitHubOAuth}>绑定 GitHub 账号</Button>
)
}
{
status.discord_oauth && (
<Button onClick={openDiscordOAuth}>绑定 Discord 账号</Button>
)
}
{
status.google_oauth && (
<Button onClick={openGoogleOAuth}>绑定 Google 账号</Button>
)
}
<Button
onClick={() => {
setShowEmailBindModal(true);
@@ -299,7 +285,6 @@ const PersonalSetting = () => {
) : (
<></>
)}
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}>
<Button
color=''
fluid
@@ -307,17 +292,8 @@ const PersonalSetting = () => {
onClick={bindEmail}
loading={loading}
>
确认绑定
绑定
</Button>
<div style={{ width: '1rem' }}></div>
<Button
fluid
size='large'
onClick={() => setShowEmailBindModal(false)}
>
取消
</Button>
</div>
</Form>
</Modal.Description>
</Modal.Content>
@@ -329,9 +305,8 @@ const PersonalSetting = () => {
size={'tiny'}
style={{ maxWidth: '450px' }}
>
<Modal.Header>危险操作</Modal.Header>
<Modal.Header>确认删除自己的帐户</Modal.Header>
<Modal.Content>
<Message>您正在删除自己的帐户将清空所有数据且不可恢复</Message>
<Modal.Description>
<Form size='large'>
<Form.Input
@@ -351,25 +326,15 @@ const PersonalSetting = () => {
) : (
<></>
)}
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: '1rem' }}>
<Button
color='red'
fluid
size='large'
onClick={deleteAccount}
loading={loading}
>
确认删除
</Button>
<div style={{ width: '1rem' }}></div>
<Button
fluid
size='large'
onClick={() => setShowAccountDeleteModal(false)}
>
取消
</Button>
</div>
<Button
color='red'
fluid
size='large'
onClick={deleteAccount}
loading={loading}
>
删除
</Button>
</Form>
</Modal.Description>
</Modal.Content>

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Popup, Pagination, Table } from 'semantic-ui-react';
import { Button, Form, Label, Message, Pagination, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, copy, showError, showInfo, showSuccess, showWarning, timestamp2string } from '../helpers';
@@ -240,25 +240,15 @@ const RedemptionsTable = () => {
>
复制
</Button>
<Popup
trigger={
<Button size='small' negative>
删除
</Button>
}
on='click'
flowing
hoverable
<Button
size={'small'}
negative
onClick={() => {
manageRedemption(redemption.id, 'delete', idx);
}}
>
<Button
negative
onClick={() => {
manageRedemption(redemption.id, 'delete', idx);
}}
>
确认删除
</Button>
</Popup>
删除
</Button>
<Button
size={'small'}
disabled={redemption.status === 3} // used

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react';
import { API, removeTrailingSlash, showError } from '../helpers';
import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react';
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
const SystemSetting = () => {
let [inputs, setInputs] = useState({
@@ -8,8 +8,11 @@ const SystemSetting = () => {
PasswordRegisterEnabled: '',
EmailVerificationEnabled: '',
GitHubOAuthEnabled: '',
DiscordOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
DiscordClientId: '',
DiscordClientSecret: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -22,18 +25,16 @@ const SystemSetting = () => {
WeChatServerAddress: '',
WeChatServerToken: '',
WeChatAccountQRCodeImageURL: '',
GoogleOAuthEnabled: '',
GoogleClientId: '',
GoogleClientSecret: '',
TurnstileCheckEnabled: '',
TurnstileSiteKey: '',
TurnstileSecretKey: '',
RegisterEnabled: '',
EmailDomainRestrictionEnabled: '',
EmailDomainWhitelist: ''
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
@@ -43,15 +44,8 @@ const SystemSetting = () => {
data.forEach((item) => {
newInputs[item.key] = item.value;
});
setInputs({
...newInputs,
EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',')
});
setInputs(newInputs);
setOriginInputs(newInputs);
setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => {
return { key: item, text: item, value: item };
}));
} else {
showError(message);
}
@@ -68,9 +62,10 @@ const SystemSetting = () => {
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled':
case 'DiscordOAuthEnabled':
case 'WeChatAuthEnabled':
case 'GoogleOAuthEnabled':
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'RegisterEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
@@ -83,12 +78,7 @@ const SystemSetting = () => {
});
const { success, message } = res.data;
if (success) {
if (key === 'EmailDomainWhitelist') {
value = value.split(',');
}
setInputs((inputs) => ({
...inputs, [key]: value
}));
setInputs((inputs) => ({ ...inputs, [key]: value }));
} else {
showError(message);
}
@@ -96,23 +86,21 @@ const SystemSetting = () => {
};
const handleInputChange = async (e, { name, value }) => {
if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') {
// block disabling password login
setShowPasswordWarningModal(true);
return;
}
if (
name === 'Notice' ||
name.startsWith('SMTP') ||
name === 'ServerAddress' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'DiscordClientId' ||
name === 'DiscordClientSecret' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
name === 'GoogleClientId' ||
name === 'GoogleClientSecret' ||
name === 'TurnstileSiteKey' ||
name === 'TurnstileSecretKey' ||
name === 'EmailDomainWhitelist'
name === 'TurnstileSecretKey'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
@@ -149,16 +137,6 @@ const SystemSetting = () => {
}
};
const submitEmailDomainWhitelist = async () => {
if (
originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') &&
inputs.SMTPToken !== ''
) {
await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(','));
}
};
const submitWeChat = async () => {
if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
await updateOption(
@@ -183,6 +161,18 @@ const SystemSetting = () => {
}
};
const submitGoogleOAuth = async () => {
if (originInputs['GoogleClientId'] !== inputs.GoogleClientId) {
await updateOption('GoogleClientId', inputs.GoogleClientId);
}
if (
originInputs['GoogleClientSecret'] !== inputs.GoogleClientSecret &&
inputs.GoogleClientSecret !== ''
) {
await updateOption('GoogleClientSecret', inputs.GoogleClientSecret);
}
};
const submitGitHubOAuth = async () => {
if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) {
await updateOption('GitHubClientId', inputs.GitHubClientId);
@@ -195,6 +185,18 @@ const SystemSetting = () => {
}
};
const submitDiscordOAuth = async () => {
if (originInputs['DiscordClientId'] !== inputs.DiscordClientId) {
await updateOption('DiscordClientId', inputs.DiscordClientId);
}
if (
originInputs['DiscordClientSecret'] !== inputs.DiscordClientSecret &&
inputs.DiscordClientSecret !== ''
) {
await updateOption('DiscordClientSecret', inputs.DiscordClientSecret);
}
};
const submitTurnstile = async () => {
if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
@@ -207,22 +209,6 @@ const SystemSetting = () => {
}
};
const submitNewRestrictedDomain = () => {
const localDomainList = inputs.EmailDomainWhitelist;
if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) {
setRestrictedDomainInput('');
setInputs({
...inputs,
EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
});
setEmailDomainWhitelist([...EmailDomainWhitelist, {
key: restrictedDomainInput,
text: restrictedDomainInput,
value: restrictedDomainInput,
}]);
}
}
return (
<Grid columns={1}>
<Grid.Column>
@@ -249,32 +235,6 @@ const SystemSetting = () => {
name='PasswordLoginEnabled'
onChange={handleInputChange}
/>
{
showPasswordWarningModal &&
<Modal
open={showPasswordWarningModal}
onClose={() => setShowPasswordWarningModal(false)}
size={'tiny'}
style={{ maxWidth: '450px' }}
>
<Modal.Header>警告</Modal.Header>
<Modal.Content>
<p>取消密码登录将导致所有未绑定其他登录方式的用户包括管理员无法通过密码登录确认取消</p>
</Modal.Content>
<Modal.Actions>
<Button onClick={() => setShowPasswordWarningModal(false)}>取消</Button>
<Button
color='yellow'
onClick={async () => {
setShowPasswordWarningModal(false);
await updateOption('PasswordLoginEnabled', 'false');
}}
>
确定
</Button>
</Modal.Actions>
</Modal>
}
<Form.Checkbox
checked={inputs.PasswordRegisterEnabled === 'true'}
label='允许通过密码进行注册'
@@ -287,12 +247,24 @@ const SystemSetting = () => {
name='EmailVerificationEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DiscordOAuthEnabled === 'true'}
label='允许通过 Discord 账户登录和注册'
name='DiscordOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.GitHubOAuthEnabled === 'true'}
label='允许通过 GitHub 账户登录 & 注册'
name='GitHubOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.GoogleOAuthEnabled === 'true'}
label='允许通过 Google 账户登录和注册'
name='GoogleOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.WeChatAuthEnabled === 'true'}
label='允许通过微信登录 & 注册'
@@ -315,54 +287,6 @@ const SystemSetting = () => {
/>
</Form.Group>
<Divider />
<Header as='h3'>
配置邮箱域名白名单
<Header.Subheader>用以防止恶意用户利用临时邮箱批量注册</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱域名白名单'
name='EmailDomainRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={2}>
<Form.Dropdown
label='允许的邮箱域名'
placeholder='允许的邮箱域名'
name='EmailDomainWhitelist'
required
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.EmailDomainWhitelist}
autoComplete='new-password'
options={EmailDomainWhitelist}
/>
<Form.Input
label='添加新的允许的邮箱域名'
action={
<Button type='button' onClick={() => {
submitNewRestrictedDomain();
}}>填入</Button>
}
onKeyDown={(e) => {
if (e.key === 'Enter') {
submitNewRestrictedDomain();
}
}}
autoComplete='new-password'
placeholder='输入新的允许的邮箱域名'
value={restrictedDomainInput}
onChange={(e, { value }) => {
setRestrictedDomainInput(value);
}}
/>
</Form.Group>
<Form.Button onClick={submitEmailDomainWhitelist}>保存邮箱域名白名单设置</Form.Button>
<Divider />
<Header as='h3'>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
@@ -408,7 +332,7 @@ const SystemSetting = () => {
onChange={handleInputChange}
type='password'
autoComplete='new-password'
checked={inputs.RegisterEnabled === 'true'}
value={inputs.SMTPToken}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
@@ -496,6 +420,82 @@ const SystemSetting = () => {
保存 WeChat Server 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 Discord OAuth 应用程序
<Header.Subheader>
用以支持通过 Discord 进行登录注册
<a href='https://discord.com/developers/applications' target='_blank'>
点击此处
</a>
管理你的 Discord OAuth App
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/discord`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='Discord 客户 ID'
name='DiscordClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DiscordClientId}
placeholder='输入您注册的 Discord OAuth APP 的 ID'
/>
<Form.Input
label='Discord 客户秘密'
name='DiscordClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.DiscordClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitDiscordOAuth}>
保存 Discord OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 Google OAuth 应用程序
<Header.Subheader>
用以支持通过 Google 进行登录注册
<a href='https://console.cloud.google.com/' target='_blank'>
点击此处
</a>
管理你的 Google OAuth App
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/google`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='Google 客户 ID'
name='GoogleClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.GoogleClientId}
placeholder='输入您注册的 Google OAuth APP 的 ID'
/>
<Form.Input
label='Google 客户秘密'
name='GoogleClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.GoogleClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitGoogleOAuth}>
保存 Google OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 Turnstile
<Header.Subheader>

View File

@@ -1,22 +1,11 @@
import React, { useEffect, useState } from 'react';
import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Button, Form, Label, Modal, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers';
import { ITEMS_PER_PAGE } from '../constants';
import { renderQuota } from '../helpers/render';
const COPY_OPTIONS = [
{ key: 'next', text: 'ChatGPT Next Web', value: 'next' },
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
];
function renderTimestamp(timestamp) {
return (
<>
@@ -79,84 +68,6 @@ const TokensTable = () => {
const refresh = async () => {
setLoading(true);
await loadTokens(activePage - 1);
};
const onCopy = async (type, key) => {
let status = localStorage.getItem('status');
let serverAddress = '';
if (status) {
status = JSON.parse(status);
serverAddress = status.server_address;
}
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
const nextLink = localStorage.getItem('chat_link');
let nextUrl;
if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
} else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}
let url;
switch (type) {
case 'ama':
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
break;
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
case 'next':
url = nextUrl;
break;
default:
url = `sk-${key}`;
}
if (await copy(url)) {
showSuccess('已复制到剪贴板!');
} else {
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
setSearchKeyword(url);
}
};
const onOpenLink = async (type, key) => {
let status = localStorage.getItem('status');
let serverAddress = '';
if (status) {
status = JSON.parse(status);
serverAddress = status.server_address;
}
if (serverAddress === '') {
serverAddress = window.location.origin;
}
let encodedServerAddress = encodeURIComponent(serverAddress);
const chatLink = localStorage.getItem('chat_link');
let defaultUrl;
if (chatLink) {
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
} else {
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}
let url;
switch (type) {
case 'ama':
url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`;
break;
case 'opencat':
url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`;
break;
default:
url = defaultUrl;
}
window.open(url, '_blank');
}
useEffect(() => {
@@ -324,51 +235,21 @@ const TokensTable = () => {
<Table.Cell>{token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)}</Table.Cell>
<Table.Cell>
<div>
<Button.Group color='green' size={'small'}>
<Button
size={'small'}
positive
onClick={async () => {
await onCopy('', token.key);
}}
>
复制
</Button>
<Dropdown
className='button icon'
floating
options={COPY_OPTIONS.map(option => ({
...option,
onClick: async () => {
await onCopy(option.value, token.key);
}
}))}
trigger={<></>}
/>
</Button.Group>
{' '}
<Button.Group color='blue' size={'small'}>
<Button
size={'small'}
positive
onClick={() => {
onOpenLink('', token.key);
}}>
聊天
</Button>
<Dropdown
className="button icon"
floating
options={OPEN_LINK_OPTIONS.map(option => ({
...option,
onClick: async () => {
await onOpenLink(option.value, token.key);
}
}))}
trigger={<></>}
/>
</Button.Group>
{' '}
<Button
size={'small'}
positive
onClick={async () => {
let key = "sk-" + token.key;
if (await copy(key)) {
showSuccess('已复制到剪贴板!');
} else {
showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。');
setSearchKeyword(key);
}
}}
>
复制
</Button>
<Popup
trigger={
<Button size='small' negative>

View File

@@ -227,7 +227,7 @@ const UsersTable = () => {
content={user.email ? user.email : '未绑定邮箱地址'}
key={user.username}
header={user.display_name ? user.display_name : user.username}
trigger={<span>{renderText(user.username, 15)}</span>}
trigger={<span>{renderText(user.username, 10)}</span>}
hoverable
/>
</Table.Cell>

View File

@@ -4,8 +4,6 @@ export const CHANNEL_OPTIONS = [
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 2, text: '代理API2D', value: 2, color: 'blue' },

View File

@@ -1,5 +1,5 @@
export const toastConstants = {
SUCCESS_TIMEOUT: 1500,
SUCCESS_TIMEOUT: 500,
INFO_TIMEOUT: 3000,
ERROR_TIMEOUT: 5000,
WARNING_TIMEOUT: 10000,

View File

@@ -1,11 +1,6 @@
import { toast } from 'react-toastify';
import { toastConstants } from '../constants';
import React from 'react';
const HTMLToastContent = ({ htmlContent }) => {
return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />;
};
export default HTMLToastContent;
export function isAdmin() {
let user = localStorage.getItem('user');
if (!user) return false;
@@ -112,12 +107,8 @@ export function showInfo(message) {
toast.info(message, showInfoOptions);
}
export function showNotice(message, isHTML = false) {
if (isHTML) {
toast(<HTMLToastContent htmlContent={message} />, showNoticeOptions);
} else {
toast.info(message, showNoticeOptions);
}
export function showNotice(message) {
toast.info(message, showNoticeOptions);
}
export function openPage(url) {

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { useParams } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
import { CHANNEL_OPTIONS } from '../../constants';
@@ -12,14 +12,9 @@ const MODEL_MAPPING_EXAMPLE = {
const EditChannel = () => {
const params = useParams();
const navigate = useNavigate();
const channelId = params.id;
const isEdit = channelId !== undefined;
const [loading, setLoading] = useState(isEdit);
const handleCancel = () => {
navigate('/channel');
};
const originInputs = {
name: '',
type: 1,
@@ -27,6 +22,8 @@ const EditChannel = () => {
base_url: '',
other: '',
model_mapping: '',
allow_streaming: 1,
allow_non_streaming: 1,
models: [],
groups: ['default']
};
@@ -40,30 +37,6 @@ const EditChannel = () => {
const [customModel, setCustomModel] = useState('');
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
if (name === 'type' && inputs.models.length === 0) {
let localModels = [];
switch (value) {
case 14:
localModels = ['claude-instant-1', 'claude-2'];
break;
case 11:
localModels = ['PaLM-2'];
break;
case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-v1', 'qwen-plus-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
break;
case 18:
localModels = ['SparkDesk'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}
};
const loadChannel = async () => {
@@ -123,6 +96,9 @@ const EditChannel = () => {
useEffect(() => {
let localModelOptions = [...originModelOptions];
if (!Array.isArray(inputs.models)) {
inputs.models = inputs.models.split(',');
}
inputs.models.forEach((model) => {
if (!localModelOptions.find((option) => option.key === model)) {
localModelOptions.push({
@@ -156,15 +132,17 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!');
return;
}
// allow streaming and allow non streaming cannot be both false
if (inputs.allow_streaming === 2 && inputs.allow_non_streaming === 2) {
showInfo('流式请求和非流式请求不能同时禁用!');
return;
}
let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
localInputs.other = '2023-06-01-preview';
}
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
localInputs.other = '2023-03-15-preview';
}
let res;
localInputs.models = localInputs.models.join(',');
@@ -208,7 +186,7 @@ const EditChannel = () => {
<Message>
注意<strong>模型部署名称必须和模型名称保持一致</strong> One API model
参数替换为你的部署名称模型名称中的点会被剔除<a target='_blank'
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>
</Message>
<Form.Field>
<Form.Input
@@ -224,7 +202,7 @@ const EditChannel = () => {
<Form.Input
label='默认 API 版本'
name='other'
placeholder={'请输入默认 API 版本例如2023-06-01-preview该配置可以被实际的请求查询参数所覆盖'}
placeholder={'请输入默认 API 版本例如2023-03-15-preview该配置可以被实际的请求查询参数所覆盖'}
onChange={handleInputChange}
value={inputs.other}
autoComplete='new-password'
@@ -303,7 +281,7 @@ const EditChannel = () => {
<Input
action={
<Button type={'button'} onClick={() => {
if (customModel.trim() === '') return;
if (customModel.trim() === "") return;
if (inputs.models.includes(customModel)) return;
let localModels = [...inputs.models];
localModels.push(customModel);
@@ -311,7 +289,7 @@ const EditChannel = () => {
localModelOptions.push({
key: customModel,
text: customModel,
value: customModel
value: customModel,
});
setModelOptions(modelOptions => {
return [...modelOptions, ...localModelOptions];
@@ -329,7 +307,7 @@ const EditChannel = () => {
</div>
<Form.Field>
<Form.TextArea
label='模型重定向'
label='模型映射'
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
name='model_mapping'
onChange={handleInputChange}
@@ -338,6 +316,26 @@ const EditChannel = () => {
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Checkbox
checked={inputs.allow_streaming === 1}
label='允许流式请求'
name='allow_streaming'
onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_streaming: inputs.allow_streaming === 1 ? 2 : 1 }));
}}
/>
</Form.Field>
<Form.Field>
<Form.Checkbox
checked={inputs.allow_non_streaming === 1}
label='允许非流式请求'
name='allow_non_streaming'
onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_non_streaming: inputs.allow_non_streaming === 1 ? 2 : 1 }));
}}
/>
</Form.Field>
{
batch ? <Form.Field>
<Form.TextArea
@@ -355,7 +353,7 @@ const EditChannel = () => {
label='密钥'
name='key'
required
placeholder={inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
placeholder={inputs.type === 15 ? "请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次" : '请输入渠道对应的鉴权密钥'}
onChange={handleInputChange}
value={inputs.key}
autoComplete='new-password'
@@ -376,9 +374,9 @@ const EditChannel = () => {
inputs.type !== 3 && inputs.type !== 8 && (
<Form.Field>
<Form.Input
label='代理'
label='镜像'
name='base_url'
placeholder={'此项可选,用于通过代理站来进行 API 调用,请输入代理站地址格式为https://domain.com'}
placeholder={'此项可选,用于通过镜像站来进行 API 调用,请输入镜像站地址格式为https://domain.com'}
onChange={handleInputChange}
value={inputs.base_url}
autoComplete='new-password'
@@ -386,8 +384,7 @@ const EditChannel = () => {
</Form.Field>
)
}
<Button onClick={handleCancel}>取消</Button>
<Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -14,11 +14,10 @@ const Home = () => {
const { success, message, data } = res.data;
if (success) {
let oldNotice = localStorage.getItem('notice');
if (data !== oldNotice && data !== '') {
const htmlNotice = marked(data);
showNotice(htmlNotice, true);
localStorage.setItem('notice', data);
}
if (data !== oldNotice && data !== '') {
showNotice(data);
localStorage.setItem('notice', data);
}
} else {
showError(message);
}
@@ -65,7 +64,7 @@ const Home = () => {
<Card.Meta>系统信息总览</Card.Meta>
<Card.Description>
<p>名称{statusState?.status?.system_name}</p>
<p>版本{statusState?.status?.version ? statusState?.status?.version : "unknown"}</p>
<p>版本{statusState?.status?.version}</p>
<p>
源码
<a
@@ -98,12 +97,24 @@ const Home = () => {
? '已启用'
: '未启用'}
</p>
<p>
Discord 身份验证
{statusState?.status?.discord_oauth === true
? '已启用'
: '未启用'}
</p>
<p>
微信身份验证
{statusState?.status?.wechat_login === true
? '已启用'
: '未启用'}
</p>
<p>
Google 身份验证
{statusState?.status?.google_oauth === true
? '已启用'
: '未启用'}
</p>
<p>
Turnstile 用户校验
{statusState?.status?.turnstile_check === true

View File

@@ -1,12 +1,11 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { useParams } from 'react-router-dom';
import { API, downloadTextAsFile, showError, showSuccess } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
const EditRedemption = () => {
const params = useParams();
const navigate = useNavigate();
const redemptionId = params.id;
const isEdit = redemptionId !== undefined;
const [loading, setLoading] = useState(isEdit);
@@ -18,10 +17,6 @@ const EditRedemption = () => {
const [inputs, setInputs] = useState(originInputs);
const { name, quota, count } = inputs;
const handleCancel = () => {
navigate('/redemption');
};
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
@@ -118,7 +113,6 @@ const EditRedemption = () => {
</>
}
<Button positive onClick={submit}>提交</Button>
<Button onClick={handleCancel}>取消</Button>
</Form>
</Segment>
</>

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { useParams } from 'react-router-dom';
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
@@ -17,13 +17,11 @@ const EditToken = () => {
};
const [inputs, setInputs] = useState(originInputs);
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
const navigate = useNavigate();
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const handleCancel = () => {
navigate("/token");
}
const setExpiredTime = (month, day, hour, minute) => {
let now = new Date();
let timestamp = now.getTime() / 1000;
@@ -85,7 +83,7 @@ const EditToken = () => {
if (isEdit) {
showSuccess('令牌更新成功!');
} else {
showSuccess('令牌创建成功,请在列表页面点击复制获取令牌');
showSuccess('令牌创建成功!');
setInputs(originInputs);
}
} else {
@@ -152,9 +150,8 @@ const EditToken = () => {
</Form.Field>
<Button type={'button'} onClick={() => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button>
<Button floated='right' positive onClick={submit}>提交</Button>
<Button floated='right' onClick={handleCancel}>取消</Button>
}}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>
</>

View File

@@ -1,6 +1,6 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { useParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
@@ -13,13 +13,15 @@ const EditUser = () => {
display_name: '',
password: '',
github_id: '',
discord_id: '',
wechat_id: '',
google_id: '',
email: '',
quota: 0,
group: 'default'
});
const [groupOptions, setGroupOptions] = useState([]);
const { username, display_name, password, github_id, wechat_id, email, quota, group } =
const { username, display_name, password, github_id, wechat_id, email, quota, google_id, discord_id } =
inputs;
const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -36,10 +38,7 @@ const EditUser = () => {
showError(error.message);
}
};
const navigate = useNavigate();
const handleCancel = () => {
navigate("/setting");
}
const loadUser = async () => {
let res = undefined;
if (userId) {
@@ -169,6 +168,26 @@ const EditUser = () => {
readOnly
/>
</Form.Field>
<Form.Field>
<Form.Input
label='已绑定的 Discord 账户'
name='discord_id'
value={discord_id}
autoComplete='new-password'
placeholder='此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改'
readOnly
/>
</Form.Field>
<Form.Field>
<Form.Input
label='已绑定的 Google 账户'
name='google_id'
value={google_id}
autoComplete='new-password'
placeholder='此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改'
readOnly
/>
</Form.Field>
<Form.Field>
<Form.Input
label='已绑定的邮箱账户'
@@ -179,7 +198,6 @@ const EditUser = () => {
readOnly
/>
</Form.Field>
<Button onClick={handleCancel}>取消</Button>
<Button positive onClick={submit}>提交</Button>
</Form>
</Segment>