Compare commits

...

136 Commits

Author SHA1 Message Date
JustSong
2ba28c72cb feat: support function call for ali (close #1242) 2024-03-30 10:43:26 +08:00
JustSong
5e81e19bc8 fix: fix SQL channel selection algo (#1197) 2024-03-27 19:09:27 +08:00
JustSong
96d7a99312 fix: fix autofilled models are not correct 2024-03-24 23:12:32 +08:00
JustSong
24be9de098 chore: update copy 2024-03-24 23:01:03 +08:00
JustSong
5b349efff9 chore: fix berry copy 2024-03-24 22:57:24 +08:00
JustSong
f76c46d648 feat: add gemini-1.5-pro (#1211) 2024-03-24 22:50:09 +08:00
JustSong
cdfdeea3b4 feat: return token when calling post /api/token (close #1208) 2024-03-24 22:24:41 +08:00
JustSong
56ddbb842a fix: return pre-consumed quota when error happened for audio (close #1217) 2024-03-24 22:20:41 +08:00
JustSong
99f81a267c fix: fix xunfei error handling (close #1218) 2024-03-24 22:14:45 +08:00
xietong
c243cd5535 feat: 支持 ollama 的 embedding 接口 (#1221)
* 增加ollama的embedding接口

* chore: fix function name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-24 21:51:31 +08:00
GuangxiaoLong
e96b173abe feat: 移除 azure model 的 TrimSuffix (#1193) 2024-03-24 21:47:46 +08:00
Benny
4ae311e964 docs: update README (#1186) 2024-03-17 21:06:36 +08:00
JustSong
b14cb748d8 chore: update copy 2024-03-17 19:39:00 +08:00
Ian Li
ade19ba4a2 feat: update default API version for Azure OpenAI (#994)
* feat: Update default API version for Azure OpenAI.

* chore: update other theme

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-17 19:34:21 +08:00
Ian Li
4d86d021c4 feat: support Azure OpenAI TTS. (#1177) 2024-03-17 19:30:50 +08:00
shuirong
7a44adb5a7 fix: fix panel cards style (#1171) 2024-03-17 19:26:12 +08:00
Benny
9821bc7281 feat: add user list sorting and pagination enhancements (#1178)
* feat: add user list sorting and pagination enhancements

* feat: add user list sorting for THEME=air

* feat: add token list sorting and pagination enhancements

* feat: add token list sorting for THEME=air
2024-03-17 19:25:36 +08:00
JustSong
08831881f1 feat: increase initial root user quota and support INITIAL_ROOT_TOKEN now (#1105) 2024-03-17 19:09:44 +08:00
JustSong
0eb2272bb7 chore: update copy 2024-03-17 18:12:49 +08:00
JustSong
704ec1a827 chore: update theme berry 2024-03-17 17:48:57 +08:00
Ghostz
1d7470d6ad fix: fix lingyiwanwu model ratio (#1182) 2024-03-17 17:04:29 +08:00
JustSong
1185303346 chore: update comments 2024-03-17 14:10:35 +08:00
JustSong
c212fcf8d7 docs: update readme 2024-03-17 14:00:33 +08:00
JustSong
c285e000cc chore: remove default scroll bar 2024-03-16 16:16:44 +08:00
JustSong
d25ed4c009 chore: update name 2024-03-16 15:55:31 +08:00
JustSong
7400885fbb fix: fix error 2024-03-16 15:41:43 +08:00
GAI Group
11af81eb39 feat: add new theme air (#1167)
* chore: add theme air with new-api main branch v0.2.0.3-alpha.1(first step)

* feat: 完成渠道界面

* chore: 优化渠道界面样式问题

* feat: 完成兑换码界面

* feat: 完成充值(钱包)界面

* chore: 初代air主题将使用default主题的运营设置界面、系统设置界面、其他设置界面

* feat: 完成日志界面

* feat: 完成用户管理界面

* feat: 完成个人设置界面

* feat: 完成令牌界面

* chore: 优化令牌界面逻辑

* feat: 修改版权信息

* chore: make necessary changes

---------

Co-authored-by: Calon <1808837298@qq.com>
Co-authored-by: Apple\Apple <zeraturing@foxmail.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-16 15:29:35 +08:00
majian
205aba694f chore: limit the temperature and top_p parameter value range to (0.0, 1) for zhipu (#1091) 2024-03-16 13:39:30 +08:00
dependabot[bot]
8dac3afebc chore(deps): bump github.com/jackc/pgx/v5 from 5.3.1 to 5.5.4 (#1157)
Bumps [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) from 5.3.1 to 5.5.4.
- [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgx/compare/v5.3.1...v5.5.4)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgx/v5
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-16 13:31:59 +08:00
zu1k
a07791bf93 fix: change Moonshot value to 25 (#1158) 2024-03-16 13:29:19 +08:00
JustSong
4bb662c0e4 docs: update pull_request_template.md 2024-03-16 13:28:44 +08:00
Benny
4998d58319 fix: fix ratio of gpt-3.5-turbo (close #1011) (#1163) 2024-03-16 13:26:11 +08:00
E.da
190203cf8f fix: 修复berry主题下令牌编辑后点击新建弹窗值的初始化问题 (#1165)
* Update OtherSetting.js

调整berry主题页脚`label`表述

* 修复`berry`主题下`令牌`在`修改`后点击`新建`弹窗值的初始化问题

### 问题描述

在`berry`主题中,存在一个问题,当用户在修改一个令牌后点击新建令牌时,新建令牌弹窗会错误地展示上一次编辑的值。

### 复现步骤

1. 导航至`令牌`管理页面。
2. 选择任意令牌进行`编辑`。
3. 在编辑界面,点击`取消`返回。
4. 点击`+新建令牌`按钮打开新建令牌弹窗。

### 预期结果

新建令牌弹窗应显示空白表单,等待用户输入新的令牌信息。

### 实际结果

新建令牌弹窗错误地展示了之前编辑的令牌信息。
2024-03-16 13:23:04 +08:00
JustSong
6325c8e0b4 ci: fix ci 2024-03-15 23:47:54 +08:00
JustSong
b204f6d82b docs: update README 2024-03-15 00:55:28 +08:00
JustSong
752639560f feat: able to use separated database for table logs 2024-03-15 00:30:15 +08:00
JustSong
996f4d99dd ci: fix ci 2024-03-14 23:53:25 +08:00
warjiang
ebfee3b46c feat: add support for private registry in docker-compose.yml (#1103) 2024-03-14 23:47:46 +08:00
dependabot[bot]
3e2e805d61 chore(deps): bump google.golang.org/protobuf from 1.30.0 to 1.33.0 (#1145)
Bumps google.golang.org/protobuf from 1.30.0 to 1.33.0.

---
updated-dependencies:
- dependency-name: google.golang.org/protobuf
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-14 23:46:17 +08:00
E.da
3edf7247c4 fix: fix theme berry copy (#1148)
调整berry主题页脚`label`表述
2024-03-14 23:45:50 +08:00
afafw
0926b6206b chore: update client name (#934) 2024-03-14 23:44:46 +08:00
JustSong
7cd57f3125 chore: update ratio for baidu embedding 2024-03-14 23:36:10 +08:00
Jguobao
66efabd5ae fix: fix baidu url check (#1143)
添加百度的另外3个向量模型【"bge-large-zh",
	"bge-large-en",
	"tao-8k",
】
2024-03-14 23:31:07 +08:00
JustSong
8ede66a896 fix: fix ci 2024-03-14 23:27:47 +08:00
JustSong
b169173860 fix: force set Accept header for ali stream request (close #1151) 2024-03-14 23:20:38 +08:00
JustSong
f33555ae78 fix: update max token for test (close #1154) 2024-03-14 23:17:19 +08:00
JustSong
c28ec10795 fix: fix cors for dashboard api 2024-03-14 23:14:39 +08:00
JustSong
e3767cbb07 fix: fix haiku model name (close #1149) 2024-03-14 23:13:05 +08:00
JustSong
be9eb59fbb feat: support lingyiwanwu 2024-03-14 23:11:36 +08:00
JustSong
89e111ac69 ci: fix ci condition 2024-03-14 01:17:19 +08:00
JustSong
2dcef85285 feat: support ollama now (close #870) 2024-03-14 01:02:47 +08:00
JustSong
79d0cd378a fix: fix baidu system prompt (close #1079) 2024-03-13 22:56:54 +08:00
JustSong
e99150bdb9 fix: make quota int64 2024-03-13 20:00:51 +08:00
JustSong
a72e5fcc9e fix: when cached quota is too low, force refresh it 2024-03-13 19:38:44 +08:00
JustSong
0710f8cd66 fix: when cached quota is too low, force refresh it 2024-03-13 19:26:24 +08:00
JustSong
49cad7d4a5 feat: update func ShouldDisableChannel for claude 2024-03-13 19:11:30 +08:00
JustSong
a90161cf00 chore: drop idx_channels_key on start 2024-03-11 02:24:58 +08:00
sparanoid
a45fc7d736 fix: model name typo (#1109) 2024-03-11 00:44:49 +08:00
JustSong
45940dcb12 chore: add more info for panic fix 2024-03-10 23:59:35 +08:00
JustSong
969042b001 chore: only use one log file (close #1116) 2024-03-10 23:44:48 +08:00
JustSong
7e7369dbc4 fix: only disable channel when allowed 2024-03-10 23:41:16 +08:00
JustSong
e54e647170 chore: remove useless code 2024-03-10 23:36:29 +08:00
JustSong
358920c858 fix: remove index idx_channels_key (close #644) 2024-03-10 23:27:22 +08:00
JustSong
1ea598c773 feat: check claude's error response 2024-03-10 20:39:55 +08:00
JustSong
796be42487 feat: update ratio config if missing 2024-03-10 19:29:42 +08:00
JustSong
5b50eb94e5 feat: able to send alert message via message pusher (close #993) 2024-03-10 19:16:06 +08:00
JustSong
71c61365eb feat: able to only test disabled channels (#1090) 2024-03-10 18:34:57 +08:00
JustSong
b09f979b80 fix: add missing turnstile setup (close #1015) 2024-03-10 18:15:24 +08:00
JustSong
12440874b0 feat: able to disable channel by success rate 2024-03-10 17:57:47 +08:00
JustSong
6ebc99460e fix: add user to blacklist when it's banned or deleted, and make deletion soft (close #473, close #791) 2024-03-10 15:56:19 +08:00
JustSong
27ad8bfb98 feat: able to search channel type now 2024-03-10 15:00:33 +08:00
JustSong
8388aa537f chore: able to search channel now 2024-03-10 14:59:57 +08:00
JustSong
2346bf70af fix: check response type when expect stream response 2024-03-10 14:59:40 +08:00
JustSong
f05b403ca5 feat: use real system prompt now (close #1079) 2024-03-10 14:32:30 +08:00
JustSong
b33616df44 feat: support groq now (close #1087) 2024-03-10 14:09:44 +08:00
JustSong
cf16f44970 feat: load channel models from server 2024-03-09 02:28:23 +08:00
JustSong
bf2e26a48f feat: support claude-3 (close #1080, close #1094) 2024-03-09 01:12:47 +08:00
momomobinx
4fb22ad4ce feat: support third part models of baidu (#1046)
百度千帆平台上的第三方大模型调用
2024-03-03 23:50:28 +08:00
JustSong
95cfb8e8c9 fix: using the first available model if default model is not found (close #1021) 2024-03-03 22:58:41 +08:00
JustSong
c6ace985c2 fix: set missing ali parameters (close #1028) 2024-03-03 22:51:01 +08:00
JustSong
10a926b8f3 feat: only use the top priority when first retry (#1048) 2024-03-03 22:16:34 +08:00
JustSong
2df877a352 feat: switch priority when retry (close #1048) 2024-03-03 22:14:07 +08:00
JustSong
9d8967f7d3 feat: support Mistral's models now (close #1051) 2024-03-03 21:46:45 +08:00
JustSong
b35f3523d3 feat: add gemini model alias (close #1064) 2024-03-03 21:03:04 +08:00
JustSong
82e916b5ff fix: fix azure test (close #1069) 2024-03-03 20:51:28 +08:00
JustSong
de18d6fe16 refactor: refactor image relay (close #1068) 2024-03-03 19:30:11 +08:00
JustSong
1d0b7fb5ae feat: support chatglm-4 (close #1045, close #952, close #952, close #943) 2024-03-02 03:05:25 +08:00
JustSong
f9490bb72e fix: able to use updated default ratio 2024-03-02 01:32:04 +08:00
JustSong
76467285e8 docs: update readme 2024-03-02 01:25:21 +08:00
JustSong
df1fd9aa81 feat: support minimax's models now (close #354) 2024-03-02 01:24:28 +08:00
JustSong
614c2e0442 feat: support baichuan's models now (close #1057) 2024-03-02 00:55:48 +08:00
JustSong
eac6a0b9aa fix: fix version is blank 2024-03-02 00:03:29 +08:00
JustSong
b747cdbc6f fix: fix getAndValidateTextRequest failed: unexpected end of JSON input (close #1043) 2024-02-26 22:52:16 +08:00
JustSong
6b27d6659a fix: add role for ChatCompletionsStreamResponseChoice.Delta 2024-02-25 19:49:22 +08:00
JustSong
dc5b781191 fix: fix stream response id 2024-02-25 19:47:59 +08:00
JustSong
c880b4a9a3 fix: fix missing index in ChatCompletionsStreamResponseChoice (#1037) 2024-02-25 19:17:37 +08:00
JustSong
565ea58e68 feat: built in retry supported (close #1036, close #770) 2024-02-25 19:01:49 +08:00
JustSong
f141a37a9e fix: fix "error update user quota cache: Error 1040: Too many connections" 2024-02-25 16:58:14 +08:00
JustSong
5b78886ad3 fix: fix i18n 2024-02-25 16:53:46 +08:00
JustSong
87c7c4f0e6 fix: rm history build before building 2024-02-25 02:07:34 +08:00
JustSong
4c4a873890 fix: add an ending line for THEMES 2024-02-25 01:59:40 +08:00
JustSong
0664bdfda1 fix: fix build.sh (close #1026) 2024-02-25 01:53:27 +08:00
JustSong
32387d9c20 fix: fix version is blank 2024-02-21 22:21:01 +08:00
JustSong
bd888f2eb7 fix: fix prompt token is zero (close #1023) 2024-02-21 22:19:42 +08:00
JustSong
cece77e533 fix: fix model list 2024-02-19 22:20:18 +08:00
JustSong
2a5468e23c refactor: remove useless button (close #1014) 2024-02-18 22:21:37 +08:00
JustSong
d0e415893b fix: fix SparkDesk model name 2024-02-18 17:16:11 +08:00
JustSong
6cf5ce9a7a fix: fix SparkDesk model name 2024-02-18 17:11:16 +08:00
JustSong
f598b9df87 feat: add new SparkDesk models 2024-02-18 17:02:36 +08:00
JustSong
532c50d212 fix: fix channel table page copy 2024-02-18 16:19:14 +08:00
JustSong
2acc2f5017 feat: support moonshot now (close #804) 2024-02-18 16:17:19 +08:00
JustSong
604ac56305 fix: set seed parameter for qwen (close #1005) 2024-02-18 15:01:09 +08:00
JustSong
9383b638a6 feat: add ChatPro & ChatStd for tencent (#1010) 2024-02-18 14:40:01 +08:00
JustSong
28d512a675 refactor: delete useless code 2024-02-18 02:23:31 +08:00
JustSong
de9a58ca0b refactor: use config field to save config 2024-02-18 02:22:50 +08:00
JustSong
1aa374ccfb refactor: use adaptor to do relay & test 2024-02-18 00:15:31 +08:00
Laisky.Cai
d548a01c59 feat: Handle errors, validate model names, and calculate quota usage (#978)
- Improved error handling in various modules for better stability and responsiveness.
- Optimized code in several files for improved efficiency and readability.
- Enhanced user experience by providing more detailed error responses in the controller.
- Strengthened security by ignoring sensitive files in `.gitignore`.
2024-02-12 21:35:40 +08:00
JustSong
2cd1a78203 chore: update module name 2024-01-28 19:38:58 +08:00
JustSong
b9d3cb0c45 refactor: split RelayTextHelper function 2024-01-28 19:14:46 +08:00
JustSong
ea407f0054 feat: able to set completion ration now (close #968) 2024-01-28 16:45:54 +08:00
Benny
26e2e646cb feat: sync models with OpenAI (#971)
* add new 0125 chat models and embedding-3 models

* refine the step of manually deploying

* add gpt-4-turbo-preview
2024-01-28 16:09:21 +08:00
yongman
4f214c48c6 fix: fix primary chat button (#951)
Signed-off-by: yongman <yming0221@gmail.com>
2024-01-21 23:27:34 +08:00
JustSong
2d760d4a01 refactor: refactor relay part (#957)
* refactor: refactor relay part

* refactor: refactor config part
2024-01-21 23:21:42 +08:00
Buer
e2ed0399f0 fix: fix aff not effective (#937) 2024-01-20 12:38:06 +08:00
JustSong
eed9f5fdf0 refactor: refactor relay part (#935) 2024-01-14 19:21:03 +08:00
JustSong
f2c51a494c feat: able to login via email (close #921) 2024-01-14 14:08:39 +08:00
Calcium-Ion
8a4d6f3327 fix: remove useless wrong index (#916)
* add sqlite busy_timeout=3000

* chore: update impl

* fix: fix JSON tag in Log struct

* fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费

* Revert "fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费"

This reverts commit f0ffe14437.

* fix: remove wrong index

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-14 13:48:16 +08:00
Buer
cf4e33cb12 fix: fix bugs with theme berry (#931)
* fix: home page & logo style issue

* improve: Enhanced user experience by improving the channel selection box

* fix: key cannot be activated after expiration
2024-01-14 13:22:31 +08:00
JustSong
5d60305570 docs: update readme (close #930) 2024-01-14 13:00:59 +08:00
JustSong
d062bc60e4 chore: update ui copy 2024-01-07 19:18:52 +08:00
JustSong
39c1882970 chore: add back THEMES 2024-01-07 19:01:31 +08:00
JustSong
9c42c7dfd9 docs: update theme readme 2024-01-07 18:59:26 +08:00
JustSong
903aaeded0 chore: revert change 2024-01-07 18:47:39 +08:00
JustSong
bdd4be562d chore: add theme validation 2024-01-07 18:44:26 +08:00
Buer
37afb313b5 fix: fix some issues with berry (#913)
* fix: login address error

* fix: Normal users display profile menu

* fix: remove redundant code
2024-01-07 18:39:15 +08:00
JustSong
c9ebcab8b8 fix: fix theme logging 2024-01-07 18:02:59 +08:00
266 changed files with 16921 additions and 5046 deletions

View File

@@ -20,6 +20,13 @@ jobs:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION

View File

@@ -20,6 +20,13 @@ jobs:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION

View File

@@ -21,6 +21,13 @@ jobs:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION

View File

@@ -20,10 +20,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3
with:
node-version: 16
- name: Build Frontend (theme default)
- name: Build Frontend
env:
CI: ""
run: |
@@ -38,7 +44,7 @@ jobs:
- name: Build Backend (amd64)
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
- name: Build Backend (arm64)
run: |

View File

@@ -20,10 +20,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3
with:
node-version: 16
- name: Build Frontend (theme default)
- name: Build Frontend
env:
CI: ""
run: |
@@ -38,7 +44,7 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')

View File

@@ -23,10 +23,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3
with:
node-version: 16
- name: Build Frontend (theme default)
- name: Build Frontend
env:
CI: ""
run: |
@@ -41,7 +47,7 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')

3
.gitignore vendored
View File

@@ -6,4 +6,5 @@ upload
build
*.db-journal
logs
data
data
/web/node_modules

View File

@@ -12,6 +12,10 @@ WORKDIR /web/berry
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
WORKDIR /web/air
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
ENV GO111MODULE=on \
@@ -23,7 +27,7 @@ ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine

View File

@@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`.
git clone https://github.com/songquanpeng/one-api.git
# Build the frontend
cd one-api/web
cd one-api/web/default
npm install
npm run build
# Build the backend
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```
@@ -241,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `SESSION_SECRET=random_string`
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL.
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5`
### Command Line Parameters

View File

@@ -135,12 +135,12 @@ sudo service nginx restart
git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド
cd one-api/web
cd one-api/web/default
npm install
npm run build
# バックエンドのビルド
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```
@@ -242,17 +242,18 @@ graph LR
+ 例: `SESSION_SECRET=random_string`
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる
4. `LOG_SQL_DSN`: 設定ると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください
5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
+ 例: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
+ 例: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
+ 例: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
+ 例: `POLLING_INTERVAL=5`
### コマンドラインパラメータ

View File

@@ -67,19 +67,27 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
+ [x] [Moonshot AI](https://platform.moonshot.cn/)
+ [x] [百川大模型](https://platform.baichuan-ai.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [x] [MINIMAX](https://api.minimax.chat/)
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。
6. 支持**令牌管理**,设置令牌的过期时间和额度。
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**道管理**,批量创建道。
8. 支持**道管理**,批量创建道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
10. 支持渠道**设置模型列表**。
11. 支持**查看额度明细**。
@@ -100,6 +108,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
## 部署
### 基于 Docker 进行部署
@@ -174,12 +183,12 @@ docker-compose ps
git clone https://github.com/songquanpeng/one-api.git
# 构建前端
cd one-api/web
cd one-api/web/default
npm install
npm run build
# 构建后端
cd ..
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
````
@@ -340,35 +349,40 @@ graph LR
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置
4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL
5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
+ 例子:`SYNC_FREQUENCY=60`
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
13. 请求频率限制:
14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
17. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
17. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
18. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -407,13 +421,16 @@ https://openai.justsong.cn
+ 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游道 429 了。
+ 上游道 429 了。
7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。
+ 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

29
common/blacklist/main.go Normal file
View File

@@ -0,0 +1,29 @@
package blacklist
import (
"fmt"
"sync"
)
var blackList sync.Map
func init() {
blackList = sync.Map{}
}
func userId2Key(id int) string {
return fmt.Sprintf("userid_%d", id)
}
func BanUser(id int) {
blackList.Store(userId2Key(id), true)
}
func UnbanUser(id int) {
blackList.Delete(userId2Key(id))
}
func IsUserBanned(id int) bool {
_, ok := blackList.Load(userId2Key(id))
return ok
}

140
common/config/config.go Normal file
View File

@@ -0,0 +1,140 @@
package config
import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var MessagePusherAddress = ""
var MessagePusherToken = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser int64 = 0
var QuotaForInviter int64 = 0
var QuotaForInvitee int64 = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold int64 = 1000
var PreConsumedQuota int64 = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
"air": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")

View File

@@ -1,110 +1,9 @@
package common
import (
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = GetOrDefaultString("THEME", "default")
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const (
RoleGuestUser = 0
@@ -113,37 +12,10 @@ const (
RoleRootUser = 100
)
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
@@ -167,57 +39,81 @@ const (
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

View File

@@ -1,7 +1,12 @@
package common
import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -15,10 +15,7 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
if err != nil {
return false
}
return true
return err == nil
}
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {

42
common/env/helper.go vendored Normal file
View File

@@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@@ -8,12 +8,24 @@ import (
"strings"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
return nil, err
}
err = c.Request.Body.Close()
_ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
@@ -31,3 +43,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -1,6 +1,9 @@
package common
import "encoding/json"
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)
var GroupRatio = map[string]float64{
"default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
logger.SysError("group ratio not found: " + name)
return 1
}
return ratio

217
common/helper/helper.go Normal file
View File

@@ -0,0 +1,217 @@
package helper
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
return inter
case int:
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -12,7 +12,7 @@ import (
"strings"
"testing"
img "one-api/common/image"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"

View File

@@ -3,6 +3,8 @@ package common
import (
"flag"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"log"
"os"
"path/filepath"
@@ -37,9 +39,9 @@ func init() {
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
config.SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
@@ -57,5 +59,6 @@ func init() {
log.Fatal(err)
}
}
logger.LogDir = *LogDir
}
}

View File

@@ -0,0 +1,7 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@@ -1,9 +1,11 @@
package common
package logger
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"io"
"log"
"os"
@@ -13,19 +15,17 @@ import (
)
const (
loggerDEBUG = "DEBUG"
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" {
if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -55,29 +55,52 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) {
func Debug(ctx context.Context, msg string) {
if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)
}
}
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func Debugf(ctx context.Context, format string, a ...any) {
Debug(ctx, fmt.Sprintf(format, a...))
}
func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...))
}
func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a...))
}
func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
if id == nil {
id = helper.GenRequestID()
}
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
if !setupLogWorking {
setupLogWorking = true
go func() {
SetupLogger()
@@ -90,11 +113,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

View File

@@ -1,23 +1,27 @@
package common
package message
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount
if receiver == "" {
return fmt.Errorf("receiver is empty")
}
if config.SMTPFrom == "" { // for compatibility
config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@")
parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
@@ -36,21 +40,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if SMTPPort == 465 {
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,
ServerName: config.SMTPServer,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
client, err := smtp.NewClient(conn, SMTPServer)
client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
@@ -58,7 +62,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Mail(SMTPFrom); err != nil {
if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
@@ -80,7 +84,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err
}
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
}
return err
}

22
common/message/main.go Normal file
View File

@@ -0,0 +1,22 @@
package message
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
const (
ByAll = "all"
ByEmail = "email"
ByMessagePusher = "message_pusher"
)
func Notify(by string, title string, description string, content string) error {
if by == ByEmail {
return SendEmail(title, config.RootUserEmail, content)
}
if by == ByMessagePusher {
return SendMessage(title, description, content)
}
return fmt.Errorf("unknown notify method: %s", by)
}

View File

@@ -0,0 +1,53 @@
package message
import (
"bytes"
"encoding/json"
"errors"
"github.com/songquanpeng/one-api/common/config"
"net/http"
)
type request struct {
Title string `json:"title"`
Description string `json:"description"`
Content string `json:"content"`
URL string `json:"url"`
Channel string `json:"channel"`
Token string `json:"token"`
}
type response struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func SendMessage(title string, description string, content string) error {
if config.MessagePusherAddress == "" {
return errors.New("message pusher address is not set")
}
req := request{
Title: title,
Description: description,
Content: content,
Token: config.MessagePusherToken,
}
data, err := json.Marshal(req)
if err != nil {
return err
}
resp, err := http.Post(config.MessagePusherAddress,
"application/json", bytes.NewBuffer(data))
if err != nil {
return err
}
var res response
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return err
}
if !res.Success {
return errors.New(res.Message)
}
return nil
}

View File

@@ -2,91 +2,95 @@ package common
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
"strings"
"time"
)
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // 0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // 0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
// https://openai.com/pricing
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -97,17 +101,87 @@ var ModelRatio = map[string]float64{
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
// https://platform.baichuan-ai.com/price
"Baichuan2-Turbo": 0.008 * RMB,
"Baichuan2-Turbo-192k": 0.016 * RMB,
"Baichuan2-53B": 0.02 * RMB,
// https://api.minimax.chat/document/price
"abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
}
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() {
DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio {
DefaultModelRatio[k] = v
}
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
}
func AddNewMissingRatio(oldRatio string) string {
newRatio := make(map[string]float64)
err := json.Unmarshal([]byte(oldRatio), &newRatio)
if err != nil {
logger.SysError("error unmarshalling old ratio: " + err.Error())
return oldRatio
}
for k, v := range DefaultModelRatio {
if _, ok := newRatio[k]; !ok {
newRatio[k] = v
}
}
jsonBytes, err := json.Marshal(newRatio)
if err != nil {
logger.SysError("error marshalling new ratio: " + err.Error())
return oldRatio
}
return string(jsonBytes)
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
logger.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
@@ -123,27 +197,45 @@ func GetModelRatio(name string) float64 {
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
ratio, ok = DefaultModelRatio[name]
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30
}
return ratio
}
func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
@@ -151,11 +243,21 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}

8
common/random.go Normal file
View File

@@ -0,0 +1,8 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

View File

@@ -3,6 +3,7 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger"
"os"
"time"
)
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
SysLog("Redis is enabled")
logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
logger.FatalLog("Redis ping test failed: " + err.Error())
}
return err
}
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
return opt
}

View File

@@ -2,215 +2,13 @@ package common
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
"github.com/songquanpeng/one-api/common/config"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
func LogQuota(quota int64) string {
if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
return fmt.Sprintf("%d 点额度", quota)
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -2,17 +2,18 @@ package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func GetSubscription(c *gin.Context) {
var remainQuota int
var usedQuota int
var remainQuota int64
var usedQuota int64
var err error
var token *model.Token
var expiredTime int64
if common.DisplayTokenStatEnabled {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
@@ -21,25 +22,27 @@ func GetSubscription(c *gin.Context) {
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
usedQuota, err = model.GetUserUsedQuota(userId)
if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId)
}
}
if expiredTime <= 0 {
expiredTime = 0
}
if err != nil {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
return
}
quota := remainQuota + usedQuota
amount := float64(quota)
if common.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
@@ -57,10 +60,10 @@ func GetSubscription(c *gin.Context) {
}
func GetUsage(c *gin.Context) {
var quota int
var quota int64
var err error
var token *model.Token
if common.DisplayTokenStatEnabled {
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
@@ -69,18 +72,18 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
return
}
amount := float64(quota)
if common.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
}
usage := OpenAIUsageResponse{
Object: "list",

View File

@@ -4,10 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
@@ -92,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := httpClient.Do(req)
res, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -292,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) {
}
func updateAllChannelsBalance() error {
channels, err := model.GetAllChannels(0, 0, true)
channels, err := model.GetAllChannels(0, 0, "all")
if err != nil {
return err
}
@@ -310,24 +314,23 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)
time.Sleep(config.RequestInterval)
}
return nil
}
func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async
err := updateAllChannelsBalance()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
//err := updateAllChannelsBalance()
//if err != nil {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": err.Error(),
// })
// return
//}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels")
logger.SysLog("updating all channels")
_ = updateAllChannelsBalance()
common.SysLog("channels update done")
logger.SysLog("channels update done")
}
}

View File

@@ -5,98 +5,36 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"one-api/model"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
request.Model = "gpt-3.5-turbo"
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2,
Stream: false,
Model: "gpt-3.5-turbo",
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL
}
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
return err, nil
}
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
return nil, nil
}
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
}
testMessage := Message{
testMessage := relaymodel.Message{
Role: "user",
Content: "hi",
}
@@ -104,6 +42,72 @@ func buildTestRequest() *ChatRequest {
return testRequest
}
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -121,9 +125,8 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
err, _ = testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -147,35 +150,9 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
func testChannels(notify bool, scope string) error {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
@@ -184,12 +161,11 @@ func testAllChannels(notify bool) error {
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
channels, err := model.GetAllChannels(0, 0, scope)
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
@@ -197,37 +173,45 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
err, openaiErr := testChannel(channel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error())
if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
}
}
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
time.Sleep(config.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("道测试完成", common.RootUserEmail, "道测试完成,如果没有收到禁用通知,说明所有道都正常")
err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(true)
func TestChannels(c *gin.Context) {
scope := c.Query("scope")
if scope == "" {
scope = "all"
}
err := testChannels(true, scope)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -245,8 +229,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
_ = testAllChannels(false)
common.SysLog("channel test finished")
logger.SysLog("testing all channels")
_ = testChannels(false, "all")
logger.SysLog("channel test finished")
}
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
@@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
})
return
}
channel.CreatedTime = common.GetTimestamp()
channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {

View File

@@ -7,9 +7,12 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
@@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
@@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
@@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
@@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
if !common.GitHubOAuthEnabled {
if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
} else {
if common.RegisterEnabled {
if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
}
func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled {
if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
state := helper.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {

View File

@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"net/http"
"one-api/common"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -2,9 +2,9 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -3,9 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -18,55 +20,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink,
"chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled,
},
})
return
}
func GetNotice(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["Notice"],
"data": config.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["About"],
"data": config.OptionMap["About"],
})
return
}
func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
config.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["HomePageContent"],
"data": config.OptionMap["HomePageContent"],
})
return
}
@@ -80,9 +82,9 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
if common.EmailDomainRestrictionEnabled {
if config.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
@@ -105,11 +107,11 @@ func SendEmailVerification(c *gin.Context) {
}
code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
"<p>您的验证码为: <strong>%s</strong></p>"+
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
err := message.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -142,13 +144,13 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
err := message.SendEmail(subject, email, content)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -2,8 +2,14 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -35,6 +41,7 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
var permission []OpenAIModelPermission
@@ -53,552 +60,63 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "dall-e-2",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "davinci-002",
Parent: nil,
},
{
Id: "babbage-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "babbage-002",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "claude-2.1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.1",
Parent: nil,
},
{
Id: "claude-2.0",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2.0",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google palm",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google gemini",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google gemini",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "qwen-max",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max",
Parent: nil,
},
{
Id: "qwen-max-longcontext",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-max-longcontext",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Parent: nil,
},
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
continue
}
adaptor := helper.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
meta := &util.RelayMeta{
ChannelType: i,
}
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
}
func DashboardListModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelId2Models,
})
}
func ListModels(c *gin.Context) {
@@ -613,14 +131,14 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
Error := relaymodel.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
}
c.JSON(200, gin.H{
"error": openAIError,
"error": Error,
})
}
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"encoding/json"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
@@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) {
var options []*model.Option
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
config.OptionMapRWMutex.Lock()
for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
Key: k,
Value: common.Interface2String(v),
Value: helper.Interface2String(v),
})
}
common.OptionMapRWMutex.Unlock()
config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) {
return
}
switch option.Key {
case "Theme":
if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题",
})
return
}
case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" {
if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
@@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" {
if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" {
if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@@ -2,9 +2,10 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 {
p = 0
}
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
key := helper.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,
Key: key,
CreatedTime: common.GetTimestamp(),
CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota,
}
err = cleanRedemption.Insert()

View File

@@ -1,322 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
//Prompt string `json:"prompt"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
const AliEnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
}
return &AliChatRequest{
Model: aliModel,
Input: AliInput{
Messages: messages,
},
Parameters: AliParameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
},
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "qwen",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,223 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = model
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,222 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max
}
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e-2"
imageSize := "1024x1024"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
var imageRequest ImageRequest
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Check prompt length
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
// channel not azure
if channelType != common.ChannelTypeAzure {
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
}
var requestBody io.Reader
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,288 +0,0 @@
package controller
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
fullTextResponse.Model = "hunyuan"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

View File

@@ -1,689 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
apiVersion := GetAPIVersion(c)
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypePaLM:
req.Header.Set("x-goog-api-key", apiKey)
case APITypeGemini:
req.Header.Set("x-goog-api-key", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

View File

@@ -1,384 +1,132 @@
package controller
import (
"bytes"
"context"
"fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
case constant.RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
fallthrough
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type WhisperJSONResponse struct {
Text string `json:"text,omitempty"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
return err
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
return
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
retryTimes = 0
}
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
break
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
if channel.Id == lastFailedChannelId {
continue
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
})
}
}
func shouldRetry(c *gin.Context, statusCode int) bool {
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if statusCode == http.StatusTooManyRequests {
return true
}
if statusCode/100 == 5 {
return true
}
if statusCode == http.StatusBadRequest {
return false
}
if statusCode/100 == 2 {
return false
}
return true
}
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
err := model.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
@@ -390,7 +138,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
err := model.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",

View File

@@ -2,9 +2,11 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
)
@@ -14,7 +16,10 @@ func GetAllTokens(c *gin.Context) {
if p < 0 {
p = 0
}
tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
order := c.Query("order")
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -119,9 +124,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
Key: helper.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
@@ -137,6 +142,7 @@ func AddToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": cleanToken,
})
return
}
@@ -187,7 +193,7 @@ func UpdateToken(c *gin.Context) {
return
}
if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@@ -3,9 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
@@ -19,7 +21,7 @@ type LoginRequest struct {
}
func Login(c *gin.Context) {
if !common.PasswordLoginEnabled {
if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
@@ -106,14 +108,14 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
if !common.RegisterEnabled {
if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
if !common.PasswordRegisterEnabled {
if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
@@ -136,7 +138,7 @@ func Register(c *gin.Context) {
})
return
}
if common.EmailVerificationEnabled {
if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -160,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username,
InviterId: inviterId,
}
if common.EmailVerificationEnabled {
if config.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
@@ -178,24 +180,27 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
return
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
}
func SearchUsers(c *gin.Context) {
@@ -282,7 +287,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
user.AccessToken = common.GetUUID()
user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -319,7 +324,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
user.AffCode = common.GetRandomString(4)
user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -726,7 +731,7 @@ func EmailBind(c *gin.Context) {
return
}
if user.Role == common.RoleRootUser {
common.RootUserEmail = email
config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
"success": true,

View File

@@ -5,9 +5,10 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
@@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" {
return "", errors.New("无效的参数")
}
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", common.WeChatServerToken)
req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{
Timeout: 5 * time.Second,
}
@@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
}
func WeChatAuth(c *gin.Context) {
if !common.WeChatAuthEnabled {
if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
@@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return
}
} else {
if common.RegisterEnabled {
if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser
@@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
}
func WeChatBind(c *gin.Context) {
if !common.WeChatAuthEnabled {
if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,

View File

@@ -2,7 +2,7 @@ version: '3.4'
services:
one-api:
image: justsong/one-api:latest
image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
container_name: one-api
restart: always
command: --log-dir /app/logs
@@ -29,12 +29,12 @@ services:
retries: 3
redis:
image: redis:latest
image: "${REGISTRY:-docker.io}/redis:latest"
container_name: redis
restart: always
db:
image: mysql:8.2.0
image: "${REGISTRY:-docker.io}/mysql:8.2.0"
restart: always
container_name: mysql
volumes:

8
go.mod
View File

@@ -1,4 +1,4 @@
module one-api
module github.com/songquanpeng/one-api
// +heroku goVersion go1.18
go 1.18
@@ -42,7 +42,8 @@ require (
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@@ -58,8 +59,9 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

12
go.sum
View File

@@ -73,8 +73,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -157,6 +159,8 @@ golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -177,8 +181,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -8,12 +8,12 @@
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
@@ -119,11 +119,11 @@
" 个月 ": " M ",
" 年 ": " y ",
"未测试": "Not tested",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"搜索渠道的 ID名称和密钥 ...": "Search for channel ID, name and key ...",
"名称": "Name",
"分组": "Group",
@@ -141,9 +141,9 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",
"处理中...": "Processing...",
"绑定成功!": "Binding succeeded!",
@@ -207,11 +207,11 @@
"监控设置": "Monitoring Settings",
"最长响应时间": "Longest Response Time",
"单位秒": "Unit in seconds",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"额度提醒阈值": "Quota reminder threshold",
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"保存监控设置": "Save Monitoring Settings",
"额度设置": "Quota Settings",
"新用户初始额度": "Initial quota for new users",
@@ -405,7 +405,7 @@
"镜像": "Mirror",
"请输入镜像站地址格式为https://domain.com可不填不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
"模型": "Model",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"填入基础模型": "Fill in the basic model",
"填入所有模型": "Fill in all models",
"清除所有模型": "Clear all models",
@@ -456,6 +456,7 @@
"已绑定的邮箱账户": "Email Account Bound",
"用户信息更新成功!": "User information updated successfully!",
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
"模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f",
"使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
"用户名称": "User Name",
"令牌名称": "Token Name",
@@ -514,7 +515,7 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为道命名": "Please name the channel",
"请为道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",

80
main.go
View File

@@ -6,11 +6,14 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/controller"
"one-api/middleware"
"one-api/model"
"one-api/router"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/router"
"os"
"strconv"
)
@@ -19,67 +22,78 @@ import (
var buildFS embed.FS
func main() {
common.SetupLogger()
common.SysLog(fmt.Sprintf("One API %s started with theme %s", common.Version, common.Theme))
logger.SetupLogger()
logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
if common.DebugEnabled {
common.SysLog("running in debug mode")
if config.DebugEnabled {
logger.SysLog("running in debug mode")
}
var err error
// Initialize SQL Database
err := model.InitDB()
model.DB, err = model.InitDB("SQL_DSN")
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
logger.FatalLog("failed to initialize database: " + err.Error())
}
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed()
if err != nil {
logger.FatalLog("database init error: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
common.FatalLog("failed to close database: " + err.Error())
logger.FatalLog("failed to close database: " + err.Error())
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error())
logger.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize options
model.InitOptionMap()
logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
config.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
if config.MemoryCacheEnabled {
logger.SysLog("memory cache enabled")
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache()
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
if config.MemoryCacheEnabled {
go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(config.SyncFrequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
config.BatchUpdateEnabled = true
logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
controller.InitTokenEncoders()
if config.EnableMetric {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
@@ -89,7 +103,7 @@ func main() {
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS)
@@ -99,6 +113,6 @@ func main() {
}
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
logger.FatalLog("failed to start HTTP server: " + err.Error())
}
}

View File

@@ -3,9 +3,10 @@ package middleware
import (
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
@@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if status.(int) == common.UserStatusDisabled {
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
})
session := sessions.Default(c)
session.Clear()
_ = session.Save()
c.Abort()
return
}
@@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
@@ -108,7 +112,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("token_name", token.Name)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
c.Set("specific_channel_id", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return

View File

@@ -2,9 +2,10 @@ package middleware
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
@@ -20,8 +21,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var requestModel string
var channel *model.Channel
channelId, ok := c.Get("channelId")
channelId, ok := c.Get("specific_channel_id")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -65,35 +67,46 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
SetupContextForSelectedChannel(c, channel, requestModel)
c.Next()
}
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
}
}

View File

@@ -3,14 +3,14 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/logger"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
requestID = param.Keys[logger.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@@ -4,8 +4,9 @@ import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"net/http"
"one-api/common"
"time"
)
@@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
}
}
}
@@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
} else {
// It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark)
}
@@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
}
func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
}
func CriticalRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
}
func DownloadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
}
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
}

View File

@@ -3,8 +3,9 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"one-api/common"
"runtime/debug"
)
@@ -12,11 +13,15 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
ctx := c.Request.Context()
logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
body, _ := common.GetRequestBody(c)
logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})

View File

@@ -3,16 +3,17 @@ package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Header(logger.RequestIdKey, id)
c.Next()
}
}

View File

@@ -4,9 +4,10 @@ import (
"encoding/json"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"net/url"
"one-api/common"
)
type turnstileCheckResponse struct {
@@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) {
if common.TurnstileCheckEnabled {
if config.TurnstileCheckEnabled {
session := sessions.Default(c)
turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil {
@@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
return
}
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {common.TurnstileSecretKey},
"secret": {config.TurnstileSecretKey},
"response": {response},
"remoteip": {c.ClientIP()},
})
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
common.SysError(err.Error())
logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -2,16 +2,17 @@ package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
logger.Error(c.Request.Context(), message)
}

View File

@@ -1,7 +1,8 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"strings"
)
@@ -13,7 +14,7 @@ type Ability struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
@@ -23,8 +24,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
var channelQuery *gorm.DB
if ignoreFirstPriority {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
} else {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {

View File

@@ -1,11 +1,14 @@
package model
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
@@ -14,10 +17,10 @@ import (
)
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
TokenCacheSeconds = config.SyncFrequency
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -42,7 +45,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set token error: " + err.Error())
logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
}
@@ -62,37 +65,48 @@ func CacheGetUserGroup(id int) (group string, err error) {
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
logger.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
logger.Error(ctx, "Redis set user quota error: "+err.Error())
}
return
}
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
return fetchAndUpdateUserQuota(ctx, id)
}
quota, err = strconv.Atoi(quotaString)
return quota, err
quota, err = strconv.ParseInt(quotaString, 10, 64)
if err != nil {
return 0, nil
}
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
return fetchAndUpdateUserQuota(ctx, id)
}
return quota, nil
}
func CacheUpdateUserQuota(id int) error {
func CacheUpdateUserQuota(ctx context.Context, id int) error {
if !common.RedisEnabled {
return nil
}
quota, err := GetUserQuota(id)
quota, err := CacheGetUserQuota(ctx, id)
if err != nil {
return err
}
@@ -100,7 +114,7 @@ func CacheUpdateUserQuota(id int) error {
return err
}
func CacheDecreaseUserQuota(id int, quota int) error {
func CacheDecreaseUserQuota(id int, quota int64) error {
if !common.RedisEnabled {
return nil
}
@@ -127,7 +141,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
@@ -178,20 +192,20 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
logger.SysLog("channels synced from database")
}
func SyncChannelCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database")
logger.SysLog("syncing channels from database")
InitChannelCache()
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -211,5 +225,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil
}

View File

@@ -1,14 +1,19 @@
package model
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"one-api/common"
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null;index"`
Key string `json:"key" gorm:"type:text"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
@@ -16,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"`
Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@@ -24,25 +29,25 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
var channels []*Channel
var err error
if selectAll {
switch scope {
case "all":
err = DB.Order("id desc").Find(&channels).Error
} else {
case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
return channels, err
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error
return channels, err
}
@@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
func (channel *Channel) GetModelMapping() map[string]string {
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
return nil
}
return *channel.ModelMapping
modelMapping := make(map[string]string)
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
return nil
}
return modelMapping
}
func (channel *Channel) Insert() error {
@@ -116,21 +127,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(),
TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime),
}).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
logger.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(),
BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance,
}).Error
if err != nil {
common.SysError("failed to update balance: " + err.Error())
logger.SysError("failed to update balance: " + err.Error())
}
}
@@ -144,29 +155,41 @@ func (channel *Channel) Delete() error {
return err
}
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
logger.SysError("failed to update ability status: " + err.Error())
}
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
logger.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled {
func UpdateChannelUsedQuota(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int64) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
common.SysError("failed to update channel used quota: " + err.Error())
logger.SysError("failed to update channel used quota: " + err.Error())
}
}

View File

@@ -3,15 +3,18 @@ package model
import (
"context"
"fmt"
"one-api/common"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"`
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
@@ -32,52 +35,52 @@ const (
)
func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(),
CreatedAt: helper.GetTimestamp(),
Type: logType,
Content: content,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
return
}
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Quota: int(quota),
ChannelId: channelId,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
logger.Error(ctx, "failed to record log: "+err.Error())
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
tx = LOG_DB
} else {
tx = DB.Where("type = ?", logType)
tx = LOG_DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
@@ -104,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
tx = LOG_DB.Where("user_id = ?", userId)
} else {
tx = DB.Where("user_id = ? and type = ?", userId, logType)
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
@@ -125,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -159,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -180,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
@@ -204,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
}
err = DB.Raw(`
err = LOG_DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,

View File

@@ -2,23 +2,28 @@ package model
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"one-api/common"
"os"
"strings"
"time"
)
var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error {
func CreateRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456")
if err != nil {
return err
@@ -29,20 +34,36 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: common.GetUUID(),
Quota: 100000000,
AccessToken: helper.GetUUID(),
Quota: 500000000000000,
}
DB.Create(&rootUser)
if config.InitialRootToken != "" {
logger.SysLog("creating initial root token as requested")
token := Token{
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
Status: common.TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: 500000000000000,
UnlimitedQuota: true,
}
DB.Create(&token)
}
}
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
@@ -52,13 +73,14 @@ func chooseDB() (*gorm.DB, error) {
})
}
// Use MySQL
common.SysLog("using MySQL as database")
logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
@@ -66,67 +88,78 @@ func chooseDB() (*gorm.DB, error) {
})
}
func InitDB() (err error) {
db, err := chooseDB()
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if err == nil {
if common.DebugEnabled {
if config.DebugSQLEnabled {
db = db.Debug()
}
DB = db
sqlDB, err := DB.DB()
sqlDB, err := db.DB()
if err != nil {
return err
return nil, err
}
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
if !config.IsMasterNode {
return db, err
}
common.SysLog("database migration started")
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
return nil, err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
logger.SysLog("database migrated")
return db, err
} else {
common.FatalLog(err)
logger.FatalLog(err)
}
return err
return db, err
}
func CloseDB() error {
sqlDB, err := DB.DB()
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
err = sqlDB.Close()
return err
}
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}

View File

@@ -1,7 +1,9 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"strconv"
"strings"
"time"
@@ -20,69 +22,71 @@ func AllOption() ([]*Option, error) {
}
func InitOptionMap() {
common.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string)
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
common.OptionMap["Footer"] = common.Footer
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["Theme"] = common.Theme
common.OptionMapRWMutex.Unlock()
config.OptionMapRWMutex.Lock()
config.OptionMap = make(map[string]string)
config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
config.OptionMap["SMTPServer"] = ""
config.OptionMap["SMTPFrom"] = ""
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
config.OptionMap["SMTPAccount"] = ""
config.OptionMap["SMTPToken"] = ""
config.OptionMap["Notice"] = ""
config.OptionMap["About"] = ""
config.OptionMap["HomePageContent"] = ""
config.OptionMap["Footer"] = config.Footer
config.OptionMap["SystemName"] = config.SystemName
config.OptionMap["Logo"] = config.Logo
config.OptionMap["ServerAddress"] = ""
config.OptionMap["GitHubClientId"] = ""
config.OptionMap["GitHubClientSecret"] = ""
config.OptionMap["WeChatServerAddress"] = ""
config.OptionMap["WeChatServerToken"] = ""
config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
config.OptionMap["MessagePusherAddress"] = ""
config.OptionMap["MessagePusherToken"] = ""
config.OptionMap["TurnstileSiteKey"] = ""
config.OptionMap["TurnstileSecretKey"] = ""
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
config.OptionMap["Theme"] = config.Theme
config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value)
if err != nil {
common.SysError("failed to update option map: " + err.Error())
logger.SysError("failed to update option map: " + err.Error())
}
}
}
@@ -90,7 +94,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing options from database")
logger.SysLog("syncing options from database")
loadOptionsFromDatabase()
}
}
@@ -112,117 +116,110 @@ func UpdateOption(key string, value string) error {
}
func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
config.OptionMapRWMutex.Lock()
defer config.OptionMapRWMutex.Unlock()
config.OptionMap[key] = value
if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
common.PasswordRegisterEnabled = boolValue
config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled":
common.PasswordLoginEnabled = boolValue
config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue
config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue
config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue
config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue
config.DisplayTokenStatEnabled = boolValue
}
}
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer":
common.SMTPServer = value
config.SMTPServer = value
case "SMTPPort":
intValue, _ := strconv.Atoi(value)
common.SMTPPort = intValue
config.SMTPPort = intValue
case "SMTPAccount":
common.SMTPAccount = value
config.SMTPAccount = value
case "SMTPFrom":
common.SMTPFrom = value
config.SMTPFrom = value
case "SMTPToken":
common.SMTPToken = value
config.SMTPToken = value
case "ServerAddress":
common.ServerAddress = value
config.ServerAddress = value
case "GitHubClientId":
common.GitHubClientId = value
config.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
config.GitHubClientSecret = value
case "Footer":
common.Footer = value
config.Footer = value
case "SystemName":
common.SystemName = value
config.SystemName = value
case "Logo":
common.Logo = value
config.Logo = value
case "WeChatServerAddress":
common.WeChatServerAddress = value
config.WeChatServerAddress = value
case "WeChatServerToken":
common.WeChatServerToken = value
config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
config.WeChatAccountQRCodeImageURL = value
case "MessagePusherAddress":
config.MessagePusherAddress = value
case "MessagePusherToken":
config.MessagePusherToken = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
config.TurnstileSiteKey = value
case "TurnstileSecretKey":
common.TurnstileSecretKey = value
config.TurnstileSecretKey = value
case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value)
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInviter":
common.QuotaForInviter, _ = strconv.Atoi(value)
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInvitee":
common.QuotaForInvitee, _ = strconv.Atoi(value)
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value)
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
config.TopUpLink = value
case "ChatLink":
common.ChatLink = value
config.ChatLink = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme":
common.Theme = value
config.Theme = value
}
return err
}

View File

@@ -3,8 +3,9 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
"one-api/common"
)
type Redemption struct {
@@ -13,7 +14,7 @@ type Redemption struct {
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
Quota int64 `json:"quota" gorm:"bigint;default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
@@ -41,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err
}
func Redeem(key string, userId int) (quota int, err error) {
func Redeem(key string, userId int) (quota int64, err error) {
if key == "" {
return 0, errors.New("未提供兑换码")
}
@@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return err
}
redemption.RedeemedTime = common.GetTimestamp()
redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err

View File

@@ -3,8 +3,12 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"gorm.io/gorm"
"one-api/common"
)
type Token struct {
@@ -16,15 +20,26 @@ type Token struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
case "used_quota":
query = query.Order("used_quota desc")
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -39,7 +54,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
token, err = CacheGetTokenByKey(key)
if err != nil {
common.SysError("CacheGetTokenByKey failed: " + err.Error())
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌")
}
@@ -53,12 +68,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
@@ -69,7 +84,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token status" + err.Error())
logger.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
@@ -134,51 +149,51 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func IncreaseTokenQuota(id int, quota int) (err error) {
func IncreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
func increaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(),
"accessed_time": helper.GetTimestamp(),
},
).Error
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
func DecreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
func decreaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(),
"accessed_time": helper.GetTimestamp(),
},
).Error
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@@ -196,24 +211,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota {
return errors.New("用户额度不足")
}
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(token.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
logger.SysError("failed to fetch user email: " + err.Error())
}
prompt := "您的额度即将用尽"
if noMoreQuota {
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email,
topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = message.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {
common.SysError("failed to send email" + err.Error())
logger.SysError("failed to send email" + err.Error())
}
}
}()
@@ -228,7 +243,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
return err
}
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)

View File

@@ -3,8 +3,12 @@ package model
import (
"errors"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"one-api/common"
"strings"
)
@@ -15,16 +19,16 @@ type User struct {
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Role int `json:"role" gorm:"type:int;default:1"` // admin, util
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Quota int64 `json:"quota" gorm:"bigint;default:0"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
@@ -36,9 +40,22 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
return users, err
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -89,24 +106,24 @@ func (user *User) Insert(inviterId int) error {
return err
}
}
user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID()
user.AffCode = common.GetRandomString(4)
user.Quota = config.QuotaForNewUser
user.AccessToken = helper.GetUUID()
user.AffCode = helper.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
return nil
@@ -120,6 +137,11 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
if user.Status == common.UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
return err
}
@@ -128,7 +150,10 @@ func (user *User) Delete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Delete(user).Error
blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}
@@ -141,7 +166,15 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
DB.Where(User{Username: user.Username}).First(user)
err = DB.Where("username = ?", user.Username).First(user).Error
if err != nil {
// we must make sure check username firstly
// consider this case: a malicious user set his username as other's email
err := DB.Where("email = ?", user.Username).First(user).Error
if err != nil {
return errors.New("用户名或密码错误,或用户已被封禁")
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
@@ -224,7 +257,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
common.SysError("no such user " + err.Error())
logger.SysError("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -254,12 +287,12 @@ func ValidateAccessToken(token string) (user *User) {
return nil
}
func GetUserQuota(id int) (quota int, err error) {
func GetUserQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
return quota, err
}
func GetUserUsedQuota(id int) (quota int, err error) {
func GetUserUsedQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
return quota, err
}
@@ -279,34 +312,34 @@ func GetUserGroup(id int) (group string, err error) {
return group, err
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
func increaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
func DecreaseUserQuota(id int, quota int) (err error) {
func DecreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.BatchUpdateEnabled {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
func decreaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
@@ -316,8 +349,8 @@ func GetRootUserEmail() (email string) {
return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
@@ -325,7 +358,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
@@ -333,25 +366,25 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error())
logger.SysError("failed to update user used quota and request count: " + err.Error())
}
}
func updateUserUsedQuota(id int, quota int) {
func updateUserUsedQuota(id int, quota int64) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
logger.SysError("failed to update user request count: " + err.Error())
}
}

View File

@@ -1,7 +1,8 @@
package model
import (
"one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"sync"
"time"
)
@@ -15,12 +16,12 @@ const (
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
var batchUpdateStores []map[int]int64
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
@@ -28,13 +29,13 @@ func init() {
func InitBatchUpdater() {
go func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
}
func addNewRecord(type_ int, id int, value int) {
func addNewRecord(type_ int, id int, value int64) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
@@ -45,11 +46,11 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
common.SysLog("batch update started")
logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateStores[i] = make(map[int]int64)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
@@ -57,21 +58,21 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
common.SysError("failed to batch update user quota: " + err.Error())
logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
common.SysError("failed to batch update token quota: " + err.Error())
logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
updateUserRequestCount(key, int(value))
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
common.SysLog("batch update finished")
logger.SysLog("batch update finished")
}

55
monitor/channel.go Normal file
View File

@@ -0,0 +1,55 @@
package monitor
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
)
func notifyRootUser(subject string, content string) {
if config.MessagePusherAddress != "" {
err := message.SendMessage(subject, content, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error()))
} else {
return
}
}
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := message.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

79
monitor/metric.go Normal file
View File

@@ -0,0 +1,79 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
)
var store = make(map[int][]bool)
var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
var metricFailChan = make(chan int, config.MetricFailChanSize)
func consumeSuccess(channelId int) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], true)
}
func consumeFail(channelId int) (bool, float64) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], false)
successCount := 0
for _, success := range store[channelId] {
if success {
successCount++
}
}
successRate := float64(successCount) / float64(len(store[channelId]))
if len(store[channelId]) < config.MetricQueueSize {
return false, successRate
}
if successRate < config.MetricSuccessRateThreshold {
store[channelId] = make([]bool, 0)
return true, successRate
}
return false, successRate
}
func metricSuccessConsumer() {
for {
select {
case channelId := <-metricSuccessChan:
consumeSuccess(channelId)
}
}
}
func metricFailConsumer() {
for {
select {
case channelId := <-metricFailChan:
disable, successRate := consumeFail(channelId)
if disable {
go MetricDisableChannel(channelId, successRate)
}
}
}
}
func init() {
if config.EnableMetric {
go metricSuccessConsumer()
go metricFailConsumer()
}
}
func Emit(channelId int, success bool) {
if !config.EnableMetric {
return
}
go func() {
if success {
metricSuccessChan <- channelId
} else {
metricFailChan <- channelId
}
}()
}

View File

@@ -1,9 +1,10 @@
[//]: # (请按照以下格式关联 issue)
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢)
[//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢)
[//]: # (项目维护者一般仅在周末处理 PR因此如若未能及时回复希望能理解)
[//]: # (开发者交流群910657413)
[//]: # (请在提交 PR 之前删除上面的注释)
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:
我已确认该 PR 已自测通过,相关截图如下:
(此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图)

View File

@@ -0,0 +1,8 @@
package ai360
var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}

View File

@@ -0,0 +1,60 @@
package aiproxy
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aiproxy"
}

View File

@@ -0,0 +1,9 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
var ModelList = []string{""}
func init() {
ModelList = openai.ModelList
}

View File

@@ -1,63 +1,37 @@
package controller
package aiproxy
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &AIProxyLibraryRequest{
return &LibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 {
return ""
}
@@ -68,52 +42,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := OpenAITextResponseChoice{
choice := openai.TextResponseChoice{
Index: 0,
Message: Message{
Message: model.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Created: helper.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
return &openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Created: helper.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -143,15 +117,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -160,7 +134,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -169,7 +143,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -179,28 +153,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
@@ -211,10 +185,13 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,32 @@
package aiproxy
type LibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type LibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type LibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type LibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []LibraryDocument `json:"documents"`
LibraryError
}
type LibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []LibraryDocument `json:"documents"`
}

View File

@@ -0,0 +1,86 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream {
req.Header.Set("Accept", "text/event-stream")
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -0,0 +1,6 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}

267
relay/channel/ali/main.go Normal file
View File

@@ -0,0 +1,267 @@
package ali
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
if request.TopP >= 1 {
request.TopP = 0.9999
}
return &ChatRequest{
Model: aliModel,
Input: Input{
Messages: messages,
},
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
TopK: request.TopK,
ResultFormat: "message",
},
Tools: request.Tools,
}
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(aliResponse.Output.Choices) == 0 {
return nil
}
aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta = aliChoice.Message
if aliChoice.FinishReason != "null" {
finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,81 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -0,0 +1,63 @@
package anthropic
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "anthropic"
}

View File

@@ -0,0 +1,8 @@
package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}

View File

@@ -0,0 +1,272 @@
package anthropic
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func stopReasonClaude2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
}
for _, message := range textRequest.Messages {
if message.Role == "system" && claudeRequest.System == "" {
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
return &claudeRequest
}
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,75 @@
package anthropic
// https://docs.anthropic.com/claude/reference/messages_post
type Metadata struct {
UserId string `json:"user_id"`
}
type ImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
}

View File

@@ -0,0 +1,7 @@
package baichuan
var ModelList = []string{
"Baichuan2-Turbo",
"Baichuan2-Turbo-192k",
"Baichuan-Text-Embedding",
}

View File

@@ -0,0 +1,118 @@
package baidu
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(meta.ActualModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(meta.ActualModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
suffix = "embeddings/"
}
switch meta.ActualModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += meta.ActualModelName
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "baidu"
}

View File

@@ -0,0 +1,13 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}

View File

@@ -1,4 +1,4 @@
package controller
package baidu
import (
"bufio"
@@ -6,9 +6,14 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
@@ -16,148 +21,111 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type BaiduMessage struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
baiduRequest := ChatRequest{
Messages: make([]Message, 0, len(request.Messages)),
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: request.MaxTokens,
UserId: request.User,
}
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
baiduRequest.System = message.StringContent()
} else {
messages = append(messages, BaiduMessage{
baiduRequest.Messages = append(baiduRequest.Messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: Message{
Message: model.Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
fullTextResponse := openai.TextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []OpenAITextResponseChoice{choice},
Choices: []openai.TextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &constant.StopFinishReason
}
response := ChatCompletionsStreamResponse{
response := openai.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
@@ -166,8 +134,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -194,14 +162,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
setEventStreamHeaders(c)
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -212,7 +180,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -224,28 +192,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduChatResponse
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -258,7 +226,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -266,23 +234,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduEmbeddingResponse
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -294,7 +262,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -302,10 +270,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
return nil, &fullTextResponse.Usage
}
func getBaiduAccessToken(apiKey string) (string, error) {
func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
var accessToken AccessToken
if accessToken, ok = val.(AccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
@@ -320,12 +288,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
return "", errors.New("GetAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
@@ -337,13 +305,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req)
res, err := util.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var accessToken BaiduAccessToken
var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err

View File

@@ -0,0 +1,50 @@
package baidu
import (
"github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage model.Usage `json:"usage"`
Error
}
type ChatStreamResponse struct {
ChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type EmbeddingRequest struct {
Input []string `json:"input"`
}
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type EmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage model.Usage `json:"usage"`
Error
}
type AccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}

51
relay/channel/common.go Normal file
View File

@@ -0,0 +1,51 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := DoRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}

View File

@@ -0,0 +1,66 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "google gemini"
}

View File

@@ -0,0 +1,8 @@
package gemini
// https://ai.google.dev/models/gemini
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}

View File

@@ -1,13 +1,19 @@
package controller
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"strings"
"github.com/gin-gonic/gin"
@@ -16,79 +22,39 @@ import (
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
GeminiVisionMaxImageNum = 16
VisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
@@ -96,30 +62,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
content := ChatContent{
Role: message.Role,
Parts: []GeminiPart{
Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
var parts []Part
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
if part.Type == model.ContentTypeText {
parts = append(parts, Part{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
} else if part.Type == model.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
@@ -141,9 +107,9 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
Parts: []GeminiPart{
Parts: []Part{
{
Text: "Okay",
},
@@ -156,12 +122,12 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
@@ -171,37 +137,37 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
type ChatCandidate struct {
Content ChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
choice := openai.TextResponseChoice{
Index: i,
Message: Message{
Message: model.Message{
Role: "assistant",
Content: "",
},
FinishReason: stopFinishReason,
FinishReason: constant.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
@@ -211,18 +177,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -252,7 +218,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
}
stopChan <- true
}()
setEventStreamHeaders(c)
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -264,18 +230,18 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -287,28 +253,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -318,9 +284,9 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = model
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -328,7 +294,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

Some files were not shown because too many files have changed in this diff Show More