Compare commits

...

227 Commits

Author SHA1 Message Date
JustSong
065da8ef8c fix: fix ali function call (#1242) 2024-04-04 00:46:30 +08:00
JustSong
e3cfb1fa52 feat: use given usage if available in stream mode 2024-03-31 23:41:52 +08:00
JustSong
f89ae5ad58 feat: initial function call support for xunfei 2024-03-31 23:12:29 +08:00
JustSong
06a3fc5421 chore: update GeneralOpenAIRequest 2024-03-31 22:23:42 +08:00
ManJieqi
a9c464ec5a fix: update model-ratio.go 修正文心计费模型名称
统一文心计费模型名称
2024-03-30 11:06:31 +08:00
JustSong
3f3c13c98c feat: support top_k for claude (close #1239) 2024-03-30 10:47:07 +08:00
JustSong
2ba28c72cb feat: support function call for ali (close #1242) 2024-03-30 10:43:26 +08:00
JustSong
5e81e19bc8 fix: fix SQL channel selection algo (#1197) 2024-03-27 19:09:27 +08:00
JustSong
96d7a99312 fix: fix autofilled models are not correct 2024-03-24 23:12:32 +08:00
JustSong
24be9de098 chore: update copy 2024-03-24 23:01:03 +08:00
JustSong
5b349efff9 chore: fix berry copy 2024-03-24 22:57:24 +08:00
JustSong
f76c46d648 feat: add gemini-1.5-pro (#1211) 2024-03-24 22:50:09 +08:00
JustSong
cdfdeea3b4 feat: return token when calling post /api/token (close #1208) 2024-03-24 22:24:41 +08:00
JustSong
56ddbb842a fix: return pre-consumed quota when error happened for audio (close #1217) 2024-03-24 22:20:41 +08:00
JustSong
99f81a267c fix: fix xunfei error handling (close #1218) 2024-03-24 22:14:45 +08:00
xietong
c243cd5535 feat: 支持 ollama 的 embedding 接口 (#1221)
* 增加ollama的embedding接口

* chore: fix function name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-24 21:51:31 +08:00
GuangxiaoLong
e96b173abe feat: 移除 azure model 的 TrimSuffix (#1193) 2024-03-24 21:47:46 +08:00
Benny
4ae311e964 docs: update README (#1186) 2024-03-17 21:06:36 +08:00
JustSong
b14cb748d8 chore: update copy 2024-03-17 19:39:00 +08:00
Ian Li
ade19ba4a2 feat: update default API version for Azure OpenAI (#994)
* feat: Update default API version for Azure OpenAI.

* chore: update other theme

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-17 19:34:21 +08:00
Ian Li
4d86d021c4 feat: support Azure OpenAI TTS. (#1177) 2024-03-17 19:30:50 +08:00
shuirong
7a44adb5a7 fix: fix panel cards style (#1171) 2024-03-17 19:26:12 +08:00
Benny
9821bc7281 feat: add user list sorting and pagination enhancements (#1178)
* feat: add user list sorting and pagination enhancements

* feat: add user list sorting for THEME=air

* feat: add token list sorting and pagination enhancements

* feat: add token list sorting for THEME=air
2024-03-17 19:25:36 +08:00
JustSong
08831881f1 feat: increase initial root user quota and support INITIAL_ROOT_TOKEN now (#1105) 2024-03-17 19:09:44 +08:00
JustSong
0eb2272bb7 chore: update copy 2024-03-17 18:12:49 +08:00
JustSong
704ec1a827 chore: update theme berry 2024-03-17 17:48:57 +08:00
Ghostz
1d7470d6ad fix: fix lingyiwanwu model ratio (#1182) 2024-03-17 17:04:29 +08:00
JustSong
1185303346 chore: update comments 2024-03-17 14:10:35 +08:00
JustSong
c212fcf8d7 docs: update readme 2024-03-17 14:00:33 +08:00
JustSong
c285e000cc chore: remove default scroll bar 2024-03-16 16:16:44 +08:00
JustSong
d25ed4c009 chore: update name 2024-03-16 15:55:31 +08:00
JustSong
7400885fbb fix: fix error 2024-03-16 15:41:43 +08:00
GAI Group
11af81eb39 feat: add new theme air (#1167)
* chore: add theme air with new-api main branch v0.2.0.3-alpha.1(first step)

* feat: 完成渠道界面

* chore: 优化渠道界面样式问题

* feat: 完成兑换码界面

* feat: 完成充值(钱包)界面

* chore: 初代air主题将使用default主题的运营设置界面、系统设置界面、其他设置界面

* feat: 完成日志界面

* feat: 完成用户管理界面

* feat: 完成个人设置界面

* feat: 完成令牌界面

* chore: 优化令牌界面逻辑

* feat: 修改版权信息

* chore: make necessary changes

---------

Co-authored-by: Calon <1808837298@qq.com>
Co-authored-by: Apple\Apple <zeraturing@foxmail.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-03-16 15:29:35 +08:00
majian
205aba694f chore: limit the temperature and top_p parameter value range to (0.0, 1) for zhipu (#1091) 2024-03-16 13:39:30 +08:00
dependabot[bot]
8dac3afebc chore(deps): bump github.com/jackc/pgx/v5 from 5.3.1 to 5.5.4 (#1157)
Bumps [github.com/jackc/pgx/v5](https://github.com/jackc/pgx) from 5.3.1 to 5.5.4.
- [Changelog](https://github.com/jackc/pgx/blob/master/CHANGELOG.md)
- [Commits](https://github.com/jackc/pgx/compare/v5.3.1...v5.5.4)

---
updated-dependencies:
- dependency-name: github.com/jackc/pgx/v5
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-16 13:31:59 +08:00
zu1k
a07791bf93 fix: change Moonshot value to 25 (#1158) 2024-03-16 13:29:19 +08:00
JustSong
4bb662c0e4 docs: update pull_request_template.md 2024-03-16 13:28:44 +08:00
Benny
4998d58319 fix: fix ratio of gpt-3.5-turbo (close #1011) (#1163) 2024-03-16 13:26:11 +08:00
E.da
190203cf8f fix: 修复berry主题下令牌编辑后点击新建弹窗值的初始化问题 (#1165)
* Update OtherSetting.js

调整berry主题页脚`label`表述

* 修复`berry`主题下`令牌`在`修改`后点击`新建`弹窗值的初始化问题

### 问题描述

在`berry`主题中,存在一个问题,当用户在修改一个令牌后点击新建令牌时,新建令牌弹窗会错误地展示上一次编辑的值。

### 复现步骤

1. 导航至`令牌`管理页面。
2. 选择任意令牌进行`编辑`。
3. 在编辑界面,点击`取消`返回。
4. 点击`+新建令牌`按钮打开新建令牌弹窗。

### 预期结果

新建令牌弹窗应显示空白表单,等待用户输入新的令牌信息。

### 实际结果

新建令牌弹窗错误地展示了之前编辑的令牌信息。
2024-03-16 13:23:04 +08:00
JustSong
6325c8e0b4 ci: fix ci 2024-03-15 23:47:54 +08:00
JustSong
b204f6d82b docs: update README 2024-03-15 00:55:28 +08:00
JustSong
752639560f feat: able to use separated database for table logs 2024-03-15 00:30:15 +08:00
JustSong
996f4d99dd ci: fix ci 2024-03-14 23:53:25 +08:00
warjiang
ebfee3b46c feat: add support for private registry in docker-compose.yml (#1103) 2024-03-14 23:47:46 +08:00
dependabot[bot]
3e2e805d61 chore(deps): bump google.golang.org/protobuf from 1.30.0 to 1.33.0 (#1145)
Bumps google.golang.org/protobuf from 1.30.0 to 1.33.0.

---
updated-dependencies:
- dependency-name: google.golang.org/protobuf
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-03-14 23:46:17 +08:00
E.da
3edf7247c4 fix: fix theme berry copy (#1148)
调整berry主题页脚`label`表述
2024-03-14 23:45:50 +08:00
afafw
0926b6206b chore: update client name (#934) 2024-03-14 23:44:46 +08:00
JustSong
7cd57f3125 chore: update ratio for baidu embedding 2024-03-14 23:36:10 +08:00
Jguobao
66efabd5ae fix: fix baidu url check (#1143)
添加百度的另外3个向量模型【"bge-large-zh",
	"bge-large-en",
	"tao-8k",
】
2024-03-14 23:31:07 +08:00
JustSong
8ede66a896 fix: fix ci 2024-03-14 23:27:47 +08:00
JustSong
b169173860 fix: force set Accept header for ali stream request (close #1151) 2024-03-14 23:20:38 +08:00
JustSong
f33555ae78 fix: update max token for test (close #1154) 2024-03-14 23:17:19 +08:00
JustSong
c28ec10795 fix: fix cors for dashboard api 2024-03-14 23:14:39 +08:00
JustSong
e3767cbb07 fix: fix haiku model name (close #1149) 2024-03-14 23:13:05 +08:00
JustSong
be9eb59fbb feat: support lingyiwanwu 2024-03-14 23:11:36 +08:00
JustSong
89e111ac69 ci: fix ci condition 2024-03-14 01:17:19 +08:00
JustSong
2dcef85285 feat: support ollama now (close #870) 2024-03-14 01:02:47 +08:00
JustSong
79d0cd378a fix: fix baidu system prompt (close #1079) 2024-03-13 22:56:54 +08:00
JustSong
e99150bdb9 fix: make quota int64 2024-03-13 20:00:51 +08:00
JustSong
a72e5fcc9e fix: when cached quota is too low, force refresh it 2024-03-13 19:38:44 +08:00
JustSong
0710f8cd66 fix: when cached quota is too low, force refresh it 2024-03-13 19:26:24 +08:00
JustSong
49cad7d4a5 feat: update func ShouldDisableChannel for claude 2024-03-13 19:11:30 +08:00
JustSong
a90161cf00 chore: drop idx_channels_key on start 2024-03-11 02:24:58 +08:00
sparanoid
a45fc7d736 fix: model name typo (#1109) 2024-03-11 00:44:49 +08:00
JustSong
45940dcb12 chore: add more info for panic fix 2024-03-10 23:59:35 +08:00
JustSong
969042b001 chore: only use one log file (close #1116) 2024-03-10 23:44:48 +08:00
JustSong
7e7369dbc4 fix: only disable channel when allowed 2024-03-10 23:41:16 +08:00
JustSong
e54e647170 chore: remove useless code 2024-03-10 23:36:29 +08:00
JustSong
358920c858 fix: remove index idx_channels_key (close #644) 2024-03-10 23:27:22 +08:00
JustSong
1ea598c773 feat: check claude's error response 2024-03-10 20:39:55 +08:00
JustSong
796be42487 feat: update ratio config if missing 2024-03-10 19:29:42 +08:00
JustSong
5b50eb94e5 feat: able to send alert message via message pusher (close #993) 2024-03-10 19:16:06 +08:00
JustSong
71c61365eb feat: able to only test disabled channels (#1090) 2024-03-10 18:34:57 +08:00
JustSong
b09f979b80 fix: add missing turnstile setup (close #1015) 2024-03-10 18:15:24 +08:00
JustSong
12440874b0 feat: able to disable channel by success rate 2024-03-10 17:57:47 +08:00
JustSong
6ebc99460e fix: add user to blacklist when it's banned or deleted, and make deletion soft (close #473, close #791) 2024-03-10 15:56:19 +08:00
JustSong
27ad8bfb98 feat: able to search channel type now 2024-03-10 15:00:33 +08:00
JustSong
8388aa537f chore: able to search channel now 2024-03-10 14:59:57 +08:00
JustSong
2346bf70af fix: check response type when expect stream response 2024-03-10 14:59:40 +08:00
JustSong
f05b403ca5 feat: use real system prompt now (close #1079) 2024-03-10 14:32:30 +08:00
JustSong
b33616df44 feat: support groq now (close #1087) 2024-03-10 14:09:44 +08:00
JustSong
cf16f44970 feat: load channel models from server 2024-03-09 02:28:23 +08:00
JustSong
bf2e26a48f feat: support claude-3 (close #1080, close #1094) 2024-03-09 01:12:47 +08:00
momomobinx
4fb22ad4ce feat: support third part models of baidu (#1046)
百度千帆平台上的第三方大模型调用
2024-03-03 23:50:28 +08:00
JustSong
95cfb8e8c9 fix: using the first available model if default model is not found (close #1021) 2024-03-03 22:58:41 +08:00
JustSong
c6ace985c2 fix: set missing ali parameters (close #1028) 2024-03-03 22:51:01 +08:00
JustSong
10a926b8f3 feat: only use the top priority when first retry (#1048) 2024-03-03 22:16:34 +08:00
JustSong
2df877a352 feat: switch priority when retry (close #1048) 2024-03-03 22:14:07 +08:00
JustSong
9d8967f7d3 feat: support Mistral's models now (close #1051) 2024-03-03 21:46:45 +08:00
JustSong
b35f3523d3 feat: add gemini model alias (close #1064) 2024-03-03 21:03:04 +08:00
JustSong
82e916b5ff fix: fix azure test (close #1069) 2024-03-03 20:51:28 +08:00
JustSong
de18d6fe16 refactor: refactor image relay (close #1068) 2024-03-03 19:30:11 +08:00
JustSong
1d0b7fb5ae feat: support chatglm-4 (close #1045, close #952, close #952, close #943) 2024-03-02 03:05:25 +08:00
JustSong
f9490bb72e fix: able to use updated default ratio 2024-03-02 01:32:04 +08:00
JustSong
76467285e8 docs: update readme 2024-03-02 01:25:21 +08:00
JustSong
df1fd9aa81 feat: support minimax's models now (close #354) 2024-03-02 01:24:28 +08:00
JustSong
614c2e0442 feat: support baichuan's models now (close #1057) 2024-03-02 00:55:48 +08:00
JustSong
eac6a0b9aa fix: fix version is blank 2024-03-02 00:03:29 +08:00
JustSong
b747cdbc6f fix: fix getAndValidateTextRequest failed: unexpected end of JSON input (close #1043) 2024-02-26 22:52:16 +08:00
JustSong
6b27d6659a fix: add role for ChatCompletionsStreamResponseChoice.Delta 2024-02-25 19:49:22 +08:00
JustSong
dc5b781191 fix: fix stream response id 2024-02-25 19:47:59 +08:00
JustSong
c880b4a9a3 fix: fix missing index in ChatCompletionsStreamResponseChoice (#1037) 2024-02-25 19:17:37 +08:00
JustSong
565ea58e68 feat: built in retry supported (close #1036, close #770) 2024-02-25 19:01:49 +08:00
JustSong
f141a37a9e fix: fix "error update user quota cache: Error 1040: Too many connections" 2024-02-25 16:58:14 +08:00
JustSong
5b78886ad3 fix: fix i18n 2024-02-25 16:53:46 +08:00
JustSong
87c7c4f0e6 fix: rm history build before building 2024-02-25 02:07:34 +08:00
JustSong
4c4a873890 fix: add an ending line for THEMES 2024-02-25 01:59:40 +08:00
JustSong
0664bdfda1 fix: fix build.sh (close #1026) 2024-02-25 01:53:27 +08:00
JustSong
32387d9c20 fix: fix version is blank 2024-02-21 22:21:01 +08:00
JustSong
bd888f2eb7 fix: fix prompt token is zero (close #1023) 2024-02-21 22:19:42 +08:00
JustSong
cece77e533 fix: fix model list 2024-02-19 22:20:18 +08:00
JustSong
2a5468e23c refactor: remove useless button (close #1014) 2024-02-18 22:21:37 +08:00
JustSong
d0e415893b fix: fix SparkDesk model name 2024-02-18 17:16:11 +08:00
JustSong
6cf5ce9a7a fix: fix SparkDesk model name 2024-02-18 17:11:16 +08:00
JustSong
f598b9df87 feat: add new SparkDesk models 2024-02-18 17:02:36 +08:00
JustSong
532c50d212 fix: fix channel table page copy 2024-02-18 16:19:14 +08:00
JustSong
2acc2f5017 feat: support moonshot now (close #804) 2024-02-18 16:17:19 +08:00
JustSong
604ac56305 fix: set seed parameter for qwen (close #1005) 2024-02-18 15:01:09 +08:00
JustSong
9383b638a6 feat: add ChatPro & ChatStd for tencent (#1010) 2024-02-18 14:40:01 +08:00
JustSong
28d512a675 refactor: delete useless code 2024-02-18 02:23:31 +08:00
JustSong
de9a58ca0b refactor: use config field to save config 2024-02-18 02:22:50 +08:00
JustSong
1aa374ccfb refactor: use adaptor to do relay & test 2024-02-18 00:15:31 +08:00
Laisky.Cai
d548a01c59 feat: Handle errors, validate model names, and calculate quota usage (#978)
- Improved error handling in various modules for better stability and responsiveness.
- Optimized code in several files for improved efficiency and readability.
- Enhanced user experience by providing more detailed error responses in the controller.
- Strengthened security by ignoring sensitive files in `.gitignore`.
2024-02-12 21:35:40 +08:00
JustSong
2cd1a78203 chore: update module name 2024-01-28 19:38:58 +08:00
JustSong
b9d3cb0c45 refactor: split RelayTextHelper function 2024-01-28 19:14:46 +08:00
JustSong
ea407f0054 feat: able to set completion ration now (close #968) 2024-01-28 16:45:54 +08:00
Benny
26e2e646cb feat: sync models with OpenAI (#971)
* add new 0125 chat models and embedding-3 models

* refine the step of manually deploying

* add gpt-4-turbo-preview
2024-01-28 16:09:21 +08:00
yongman
4f214c48c6 fix: fix primary chat button (#951)
Signed-off-by: yongman <yming0221@gmail.com>
2024-01-21 23:27:34 +08:00
JustSong
2d760d4a01 refactor: refactor relay part (#957)
* refactor: refactor relay part

* refactor: refactor config part
2024-01-21 23:21:42 +08:00
Buer
e2ed0399f0 fix: fix aff not effective (#937) 2024-01-20 12:38:06 +08:00
JustSong
eed9f5fdf0 refactor: refactor relay part (#935) 2024-01-14 19:21:03 +08:00
JustSong
f2c51a494c feat: able to login via email (close #921) 2024-01-14 14:08:39 +08:00
Calcium-Ion
8a4d6f3327 fix: remove useless wrong index (#916)
* add sqlite busy_timeout=3000

* chore: update impl

* fix: fix JSON tag in Log struct

* fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费

* Revert "fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费而导致令牌大额欠费"

This reverts commit f0ffe14437.

* fix: remove wrong index

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-14 13:48:16 +08:00
Buer
cf4e33cb12 fix: fix bugs with theme berry (#931)
* fix: home page & logo style issue

* improve: Enhanced user experience by improving the channel selection box

* fix: key cannot be activated after expiration
2024-01-14 13:22:31 +08:00
JustSong
5d60305570 docs: update readme (close #930) 2024-01-14 13:00:59 +08:00
JustSong
d062bc60e4 chore: update ui copy 2024-01-07 19:18:52 +08:00
JustSong
39c1882970 chore: add back THEMES 2024-01-07 19:01:31 +08:00
JustSong
9c42c7dfd9 docs: update theme readme 2024-01-07 18:59:26 +08:00
JustSong
903aaeded0 chore: revert change 2024-01-07 18:47:39 +08:00
JustSong
bdd4be562d chore: add theme validation 2024-01-07 18:44:26 +08:00
Buer
37afb313b5 fix: fix some issues with berry (#913)
* fix: login address error

* fix: Normal users display profile menu

* fix: remove redundant code
2024-01-07 18:39:15 +08:00
JustSong
c9ebcab8b8 fix: fix theme logging 2024-01-07 18:02:59 +08:00
JustSong
86261cc656 feat: able to change theme 2024-01-07 17:53:05 +08:00
JustSong
8491785c9d chore: remove useless logging 2024-01-07 17:13:24 +08:00
JustSong
e848a3f7fa fix: fix error loading stats (#912) 2024-01-07 17:10:59 +08:00
JustSong
318adf5985 chore: remove useless lines 2024-01-07 16:28:44 +08:00
JustSong
965d7fc3d2 fix: fix Dockerfile 2024-01-07 15:33:33 +08:00
JustSong
aa3f605894 fix: fix Dockerfile 2024-01-07 15:17:21 +08:00
JustSong
7b8eff1f22 fix: fix Dockerfile 2024-01-07 15:15:02 +08:00
JustSong
e80cd508ba fix: fix Dockerfile 2024-01-07 15:12:30 +08:00
JustSong
d37f836d53 fix: fix Dockerfile 2024-01-07 15:03:36 +08:00
JustSong
e0b2d1ae47 fix: fix Dockerfile 2024-01-07 14:55:04 +08:00
JustSong
797ead686b fix: fix Dockerfile 2024-01-07 14:42:23 +08:00
JustSong
0d22cf9ead fix: fix Dockerfile 2024-01-07 14:38:02 +08:00
Buer
48989d4a0b feat: add new theme berry (#860)
* feat: add theme berry

* docs: add development notes

* fix: fix blank page

* chore: update implementation

* fix: fix package.json

* chore: update ui copy

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-07 14:20:07 +08:00
Seven Yu
6227eee5bc fix: fix token validation exception handling #901
* fix: fix exception handling

1. add error log for ValidateUserToken
2. update en.json

* chore: update log

---------

Co-authored-by: seven.yu <seven.yu@dji.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-07 13:32:39 +08:00
JustSong
cbf8f07747 docs: fix logo 2024-01-01 21:19:37 +08:00
JustSong
4a96031ce6 docs: update readme 2024-01-01 21:14:45 +08:00
JustSong
92886093ae docs: update readme 2024-01-01 21:10:40 +08:00
JustSong
0c022f17cb chore: update theme related code 2024-01-01 20:25:53 +08:00
JustSong
83f95935de ci: fix Dockerfile & ci 2024-01-01 19:23:46 +08:00
JustSong
aa03c89133 feat: able to add more UI theme (#860) 2024-01-01 18:55:03 +08:00
JustSong
505817ca17 chore: update en.json 2024-01-01 17:46:45 +08:00
JustSong
cb5a3df616 fix: fix pr error (#888) 2024-01-01 17:40:10 +08:00
Laisky.Cai
7772064d87 fix: support base64 encoded image_url (#872)
- Add support for base64 encoded image in OpenAI's image_url

Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-01-01 17:38:35 +08:00
Seven Yu
c50c609565 fix: fix button copywriting (#880)
* feat: rename Channel button

* fix: update en.json

---------

Co-authored-by: seven.yu <seven.yu@dji.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-01 17:09:12 +08:00
Tailen
498dea2dbb feat: add support for davinci-002 and babbage-002 (#888) 2024-01-01 17:06:17 +08:00
Zhanliang Liu
c725cc8842 fix: base 64 encoded format support of gemini-pro-vision for field image_url/url (#878) 2024-01-01 17:00:23 +08:00
Tisfeng
af8908db54 feat: able to change gemini safety setting (#867)
* perf: adjust gemini safety settings, set BLOCK_NONE by default

* feat: able to adjust by env variable

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-01-01 16:42:19 +08:00
JustSong
d8029550f7 fix: do not consume user quota if failed (close #881) 2024-01-01 16:18:50 +08:00
JustSong
f44fbe3fe7 docs: update pr template 2023-12-24 19:24:59 +08:00
JustSong
1c8922153d feat: support gemini-vision-pro 2023-12-24 18:54:32 +08:00
Laisky.Cai
f3c07e1451 fix: openai response should contains model (#841)
* fix: openai response should contains `model`

- Update model attributes in `claudeHandler` for `relay-claude.go`
- Implement model type for fullTextResponse in `relay-gemini.go`
- Add new `Model` field to `OpenAITextResponse` struct in `relay.go`

* chore: set model name response for models

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:58:31 +08:00
Bryan
40ceb29e54 fix: fix SearchUsers not working if using PostgreSQL (#778)
* fix SearchUsers

* refactor: using UsingPostgreSQL as condition

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:42:00 +08:00
dependabot[bot]
0699ecd0af chore(deps): bump golang.org/x/crypto from 0.14.0 to 0.17.0 (#840)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.14.0 to 0.17.0.
- [Commits](https://github.com/golang/crypto/compare/v0.14.0...v0.17.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-12-24 16:29:48 +08:00
moondie
ee9e746520 feat: update ali stream implementation & enable internet search (#856)
* Update relay-ali.go: 改进stream模式,添加联网搜索能力

通义千问支持stream的增量模式,不需要每次去掉上次的前缀;实测qwen-max联网模式效果不错,添加了联网模式。如果别的模型有问题可以改为单独给qwen-max开放

* 删除"stream参数"

刚发现原来阿里api没有这个参数,上次误加了。

* refactor: only enable search when specified

* fix: remove custom suffix when get model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-24 16:17:21 +08:00
Buer
a763681c2e fix: fix base64 image parse error (#858) 2023-12-24 15:35:56 +08:00
JustSong
b7fcb319da chore: check if SESSION_SECRET equals to random_string 2023-12-20 22:50:50 +08:00
JustSong
67c64e71c8 fix: fix max_tokens check 2023-12-20 21:45:33 +08:00
JustSong
97030e27f8 fix: fix gemini panic (close #833) 2023-12-17 23:30:45 +08:00
JustSong
461f5dab56 docs: update readme 2023-12-17 22:25:03 +08:00
JustSong
af378c59af docs: update readme 2023-12-17 22:19:16 +08:00
ShinChven ✨
bc6769826b feat: add condition to validate n value for non-Azure channels (#775)
- Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range.
- Fix Azure compatibility
2023-12-17 19:49:08 +08:00
Oliver Lee
0fe26cc4bd feat: update ali relay implementation (#830)
* 修改通译千问最新接口:1.删除history参数,改用官方推荐的messages参数 2.整理messages参数顺序,补充必要上下文信息 3.用autogen调试测试通过

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:43:23 +08:00
Calcium-Ion
7d6a169669 feat: able to set sqlite busy_timeout (#818)
* add sqlite busy_timeout=3000

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 19:17:00 +08:00
Ghostz
66f06e5d6f feat: reset image num to 1 when not given (#821)
* Update relay-image.go

* fix: reset image num to 1 when not given

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 18:54:08 +08:00
JustSong
6acb9537a9 fix: try to return a more meaningful error message (close #817) 2023-12-17 18:33:27 +08:00
JustSong
7069c49bdf fix: fix xunfei panic error (close #820) 2023-12-17 18:06:37 +08:00
JustSong
58dee76bf7 fix: fix Gemini stream problem 2023-12-17 16:16:18 +08:00
David Zhuang
5cf23d8698 feat: add Google Gemini Pro support (#826)
* fest: Add Google Gemini Pro, fix #810

* fest: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini

* refactor: removing unused if statement

* fest: Add dummy model message for system message in gemini model

* chore: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-17 12:48:32 +08:00
JustSong
366b82128f fix: remove incorrect logging 2023-12-10 20:44:37 +08:00
JustSong
2a70744dbf feat: add panic recover middleware 2023-12-10 19:53:33 +08:00
Qiying Wang
4c5feee0b6 feat: add image counter for gpt-4 vision (#795) 2023-12-10 19:39:46 +08:00
igophper
9ba5388367 feat: refactor response parsing logic to support multiple formats (#782)
* feat: Refactor response parsing logic to support multiple formats

The parsing logic for responses in relay.go and relay-audio.go was refactored to support multiple response formats - 'json', 'text', 'srt', 'verbose_json', and 'vtt'. The existing `WhisperResponse` struct was renamed to `WhisperJsonResponse` and a new struct `WhisperVerboseJsonResponse` was added to support the 'verbose_json' format. Additional parsing functions were added to extract text from these new response types. This change was necessary to make the parsing logic more flexible and extendable for different types of responses.

* chore: update name

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-10 18:39:14 +08:00
JustSong
379074f7d0 feat: support plugin for ali channel (close #797) 2023-12-10 17:22:52 +08:00
JustSong
01f7b0186f chore: add routes 2023-12-03 20:45:11 +08:00
Tillman Bailee
a3f80a3392 feat: enable channel when test succeed (#771)
* 增加功能: 渠道 - 测试所有通道; 设置 - 运营设置 - 监控设置 - 成功时自动启用通道

* refactor: update implementation

---------

Co-authored-by: liyujie <29959257@qq.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 20:10:57 +08:00
Zhengyi Dong
8f5b83562b fix: fix "invalidPayload" error when request Azure dall-e-3 api without optional parameter (#764)
* fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay

* Revert "fix: based on #754 add 'omitempty' in ImageRequest to fit official api reference for relay"

This reverts commit b526006ce0.

* fix: add missing omitempty

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 17:43:30 +08:00
ShinChven ✨
b7570d5c77 feat: support dalle for Azure (#754)
* feat: Add Message-ID to email headers to comply with RFC 5322

- Extract domain from SMTPFrom
- Generate a unique Message-ID
- Add Message-ID to email headers

* chore: check slice length

* feat: Add Azure compatibility for relayImageHelper

- Handle Azure channel requestURL compatibility
- Set api-key header for Azure channel authentication
- Handle Azure channel request body

fixes: https://github.com/songquanpeng/one-api/issues/751

* refactor: update implementation

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-12-03 17:34:59 +08:00
JustSong
0e73418cdf fix: fix log recording & error handling for relay-audio 2023-11-26 12:05:16 +08:00
JustSong
9889377f0e feat: support claude-2.x (close #736) 2023-11-24 21:39:44 +08:00
JustSong
b273464e77 docs: update readme 2023-11-24 21:23:16 +08:00
JustSong
b4e43d97fd docs: add pr template 2023-11-24 21:21:03 +08:00
Ian Li
3347a44023 feat: support Azure's Whisper model (#720) 2023-11-24 21:10:18 +08:00
Tillman Bailee
923e24534b fix: add Date header for email (#742)
* 修复自建邮箱发送错误: INVALID HEADER Missing required header field: "Date"

* chore: fix style

---------

Co-authored-by: liyujie <29959257@qq.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:56:53 +08:00
ShinChven ✨
b4d67ca614 fix: add Message-ID header for email (#732)
* feat: Add Message-ID to email headers to comply with RFC 5322

- Extract domain from SMTPFrom
- Generate a unique Message-ID
- Add Message-ID to email headers

* chore: check slice length

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:52:59 +08:00
igophper
d85e356b6e refactor: remove consumeQuota related logic (#738)
* feat: 删除relay-text中的consumeQuota变量

该变量始终为true,可以删除

* chore: remove useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-24 20:42:29 +08:00
JustSong
495fc628e4 feat: support gpt-4 with vision (#683, #714) 2023-11-19 18:38:54 +08:00
JustSong
76f9288c34 feat: update request struct (close #708) 2023-11-19 17:50:30 +08:00
JustSong
915d13fdd4 docs: update readme (#724) 2023-11-19 17:22:35 +08:00
Ian Li
969f539777 fix: skip JSON deserialization when accessing transcriptions and translations (#718)
* fix: Skip JSON deserialization when accessing transcriptions and translations.

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-19 16:11:39 +08:00
Buer
54e5f8ecd2 feat: support cloudflare gateway for azure (#666)
* 🐛 Fix cloudflare gateway request failure

* 🐛 fix channel test url error
2023-11-19 15:52:35 +08:00
Mikey
34d517cfa2 fix: cloudflare test & expose detailed info about test failures (#715)
* fix: cloudflare test & expose detailed info about test failures

* fix: cloudflare test & expose detailed info about test failures

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 21:45:55 +08:00
ckt1031
ddcaf95f5f feat: support tts model (#713)
* Added support for Text-to-Speech models and
endpoints

* chore: update impl

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 21:18:51 +08:00
ckt1031
1d15157f7d feat: keep sync with dall-e updates (#679)
* Updated ImageRequest struct and OpenAIModels,
added new Dall-E models and size ratios

* Fixed suspect `or`

* Refactored size ratio calculation in
relayImageHelper function

* Updated the format of resolution keys in
DalleSizeRatios map

* Added error handling for unsupported image size in
relayImageHelper function

* Added validation for number of generated images
and defined image generation ratios

* Refactored variable name from
DalleGenerationImageAmountRatios to
DalleGenerationImageAmounts

* Added validation for prompt length in
relayImageHelper function

* Updated model validation and removed size not
supported error in relayImageHelper function

* Refactored image size and model validation in
relayImageHelper function

* chore: discard binary file

* chore: update impl

---------

Co-authored-by: cktsun1031 <65409152+cktsun1031@users.noreply.github.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 20:03:16 +08:00
管宜尧
de7b9710a5 fix: fix PaLM not working issue (#667)
* bugfix for #515 最新版本谷歌PaLM模型无法使用

* update

* chore: remove unrelated file

* chore: add comment

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-17 19:40:59 +08:00
Dafei Zhao
58bb3ab6f6 fix: fix channel_id column name (#681, close #688) 2023-11-10 21:50:52 +08:00
qingfengfenga
d306cb5229 feat: add improve docker-compose.yml and support fast startup (#685)
Co-authored-by: 王彦朋 Penn Wang <penn.wang@digitwin.com.cn>
2023-11-10 21:40:00 +08:00
Yuhang
6c5307d0c4 docs: add deploy to zeabur button (#693)
* Update README.md

* Update README.en.md

* Update README.ja.md
2023-11-10 21:20:59 +08:00
Baksi
7c4505bdfc fix: numeric sorting in tables (#695)
* Update sorting method for id

* Update sorting method for id (token)

* Update sorting method for id (redemptions)

* Update sorting method for id (channel)

* chore: use same logic for all tables

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-10 21:20:05 +08:00
Mikey
9d43ec57d8 feat: sync pricing for new 1106 models (#696)
* feat: sync pricing for new 1106 models

* chore: change ratio after 2023-12-11

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-10 21:08:23 +08:00
JustSong
e5311892d1 docs: update readme 2023-11-08 23:17:12 +08:00
wzxjohn
bc7c9105f4 chore: update quota calc logic (close #599) (#627)
* fix: change quota calc code (close #599)

Use float64 during calc and do math.Ceil after calc. This will result in the quota being used slightly more than the official standard, but it will be guaranteed that it will not be less.

* chore: remove blank line

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-05 19:15:06 +08:00
wood chen
3fe76c8af7 fix: fix Cloudflare AI Gateway channel test support (#639)
* 当使用Cloudflare AI Gateway时,支持openai渠道测试

* refactor: change logic

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-11-05 19:08:25 +08:00
papersnake
c70c614018 feat: support chatglm_turbo (#648)
* feat: support chatglm_turbo

* fix: remove characterglm
2023-11-05 17:59:38 +08:00
Baksi
0d87de697c fix: fix typo (#651) 2023-11-02 22:24:22 +08:00
MaricoHan
aec343dc38 feat: support xunfei v3 (#637) 2023-10-29 22:03:01 +08:00
448 changed files with 32582 additions and 4738 deletions

View File

@@ -20,6 +20,13 @@ jobs:
- name: Check out the repo - name: Check out the repo
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION git describe --tags > VERSION

View File

@@ -20,6 +20,13 @@ jobs:
- name: Check out the repo - name: Check out the repo
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION git describe --tags > VERSION

View File

@@ -21,6 +21,13 @@ jobs:
- name: Check out the repo - name: Check out the repo
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION git describe --tags > VERSION

View File

@@ -7,6 +7,11 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -15,6 +20,12 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
@@ -23,8 +34,8 @@ jobs:
CI: "" CI: ""
run: | run: |
cd web cd web
npm install git describe --tags > VERSION
REACT_APP_VERSION=$(git describe --tags) npm run build REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
@@ -33,7 +44,7 @@ jobs:
- name: Build Backend (amd64) - name: Build Backend (amd64)
run: | run: |
go mod download go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
- name: Build Backend (arm64) - name: Build Backend (arm64)
run: | run: |

View File

@@ -7,6 +7,11 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: macos-latest runs-on: macos-latest
@@ -15,6 +20,12 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
@@ -23,8 +34,8 @@ jobs:
CI: "" CI: ""
run: | run: |
cd web cd web
npm install git describe --tags > VERSION
REACT_APP_VERSION=$(git describe --tags) npm run build REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
@@ -33,7 +44,7 @@ jobs:
- name: Build Backend - name: Build Backend
run: | run: |
go mod download go mod download
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
- name: Release - name: Release
uses: softprops/action-gh-release@v1 uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')

View File

@@ -7,6 +7,11 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: windows-latest runs-on: windows-latest
@@ -18,6 +23,12 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
@@ -25,10 +36,10 @@ jobs:
env: env:
CI: "" CI: ""
run: | run: |
cd web cd web/default
npm install npm install
REACT_APP_VERSION=$(git describe --tags) npm run build REACT_APP_VERSION=$(git describe --tags) npm run build
cd .. cd ../..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:
@@ -36,7 +47,7 @@ jobs:
- name: Build Backend - name: Build Backend
run: | run: |
go mod download go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
- name: Release - name: Release
uses: softprops/action-gh-release@v1 uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')

2
.gitignore vendored
View File

@@ -6,3 +6,5 @@ upload
build build
*.db-journal *.db-journal
logs logs
data
/web/node_modules

View File

@@ -1,10 +1,19 @@
FROM node:16 as builder FROM node:16 as builder
WORKDIR /build WORKDIR /web
COPY web/package.json .
RUN npm install
COPY ./web .
COPY ./VERSION . COPY ./VERSION .
COPY ./web .
WORKDIR /web/default
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
WORKDIR /web/berry
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
WORKDIR /web/air
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2 FROM golang AS builder2
@@ -17,8 +26,8 @@ WORKDIR /build
ADD go.mod go.sum ./ ADD go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY --from=builder /build/build ./web/build COPY --from=builder /web/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine FROM alpine

View File

@@ -3,7 +3,7 @@
</p> </p>
<p align="center"> <p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p> </p>
<div align="center"> <div align="center">
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
1. Support for multiple large models: 1. Support for multiple large models:
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
+ [x] [Anthropic Claude Series Models](https://anthropic.com) + [x] [Anthropic Claude Series Models](https://anthropic.com)
+ [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
@@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`.
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# Build the frontend # Build the frontend
cd one-api/web cd one-api/web/default
npm install npm install
npm run build npm run build
# Build the backend # Build the backend
cd .. cd ../..
go mod download go mod download
go build -ldflags "-s -w" -o one-api go build -ldflags "-s -w" -o one-api
``` ```
@@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. First, fork the code. 1. First, fork the code.
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
@@ -239,17 +241,19 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Example: `SESSION_SECRET=random_string` + Example: `SESSION_SECRET=random_string`
3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0. 3. `SQL_DSN`: When set, the specified database will be used instead of SQLite. Please use MySQL version 8.0.
+ Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` + Example: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address. 4. `LOG_SQL_DSN`: When set, a separate database will be used for the `logs` table; please use MySQL or PostgreSQL.
+ Example: `LOG_SQL_DSN=root:123456@tcp(localhost:3306)/oneapi-logs`
5. `FRONTEND_BASE_URL`: When set, the specified frontend address will be used instead of the backend address.
+ Example: `FRONTEND_BASE_URL=https://openai.justsong.cn` + Example: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen. 6. `SYNC_FREQUENCY`: When set, the system will periodically sync configurations from the database, with the unit in seconds. If not set, no sync will happen.
+ Example: `SYNC_FREQUENCY=60` + Example: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`. 7. `NODE_TYPE`: When set, specifies the node type. Valid values are `master` and `slave`. If not set, it defaults to `master`.
+ Example: `NODE_TYPE=slave` + Example: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen. 8. `CHANNEL_UPDATE_FREQUENCY`: When set, it periodically updates the channel balances, with the unit in minutes. If not set, no update will happen.
+ Example: `CHANNEL_UPDATE_FREQUENCY=1440` + Example: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen. 9. `CHANNEL_TEST_FREQUENCY`: When set, it periodically tests the channels, with the unit in minutes. If not set, no test will happen.
+ Example: `CHANNEL_TEST_FREQUENCY=1440` + Example: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval. 10. `POLLING_INTERVAL`: The time interval (in seconds) between requests when updating channel balances and testing channel availability. Default is no interval.
+ Example: `POLLING_INTERVAL=5` + Example: `POLLING_INTERVAL=5`
### Command Line Parameters ### Command Line Parameters

View File

@@ -3,7 +3,7 @@
</p> </p>
<p align="center"> <p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p> </p>
<div align="center"> <div align="center">
@@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に
1. 複数の大型モデルをサポート: 1. 複数の大型モデルをサポート:
+ [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
+ [x] [Anthropic Claude シリーズモデル](https://anthropic.com) + [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
+ [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
+ [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
@@ -135,12 +135,12 @@ sudo service nginx restart
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド # フロントエンドのビルド
cd one-api/web cd one-api/web/default
npm install npm install
npm run build npm run build
# バックエンドのビルド # バックエンドのビルド
cd .. cd ../..
go mod download go mod download
go build -ldflags "-s -w" -o one-api go build -ldflags "-s -w" -o one-api
``` ```
@@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. まず、コードをフォークする。 1. まず、コードをフォークする。
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
@@ -240,17 +242,18 @@ graph LR
+ 例: `SESSION_SECRET=random_string` + 例: `SESSION_SECRET=random_string`
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。 3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる 4. `LOG_SQL_DSN`: 設定ると、`logs`テーブルには独立したデータベースが使用されます。MySQLまたはPostgreSQLを使用してください
5. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn` + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。 6. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
+ 例: `SYNC_FREQUENCY=60` + 例: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master` 7. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
+ 例: `NODE_TYPE=slave` + 例: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。 8. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440` + 例: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。 9. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
+ 例: `CHANNEL_TEST_FREQUENCY=1440` + 例: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。 10. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
+ 例: `POLLING_INTERVAL=5` + 例: `POLLING_INTERVAL=5`
### コマンドラインパラメータ ### コマンドラインパラメータ

View File

@@ -4,7 +4,7 @@
<p align="center"> <p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a> <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p> </p>
<div align="center"> <div align="center">
@@ -51,60 +51,64 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
<a href="https://iamazing.cn/page/reward">赞赏支持</a> <a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p> </p>
> **Note** > [!NOTE]
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> >
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> **Warning** > [!WARNING]
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> **Warning** > [!WARNING]
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456` > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
## 功能 ## 功能
1. 支持多种大模型: 1. 支持多种大模型:
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
+ [x] [Mistral 系列模型](https://mistral.ai/)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn) + [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
2. 支持配置镜像以及众多第三方代理服务: + [x] [Moonshot AI](https://platform.moonshot.cn/)
+ [x] [OpenAI-SB](https://openai-sb.com) + [x] [百川大模型](https://platform.baichuan-ai.com)
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412) + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [x] [API2D](https://api2d.com/r/197971) + [x] [MINIMAX](https://api.minimax.chat/)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [Groq](https://wow.groq.com/)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI` + [x] [Ollama](https://github.com/ollama/ollama)
+ [x] 自定义渠道:例如各种未收录的第三方代理服务 + [x] [零一万物](https://platform.lingyiwanwu.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。 5. 支持**多机部署**[详见此处](#多机部署)。
6. 支持**令牌管理**,设置令牌的过期时间和额度。 6. 支持**令牌管理**,设置令牌的过期时间和额度。
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**道管理**,批量创建道。 8. 支持**道管理**,批量创建道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
10. 支持渠道**设置模型列表**。 10. 支持渠道**设置模型列表**。
11. 支持**查看额度明细**。 11. 支持**查看额度明细**。
12. 支持**用户邀请奖励**。 12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。 13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。
15. 支持模型映射,重定向用户的请求模型。 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功
16. 支持失败自动重试。 16. 支持失败自动重试。
17. 支持绘图接口。 17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置, 19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。 1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 API。 20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
21. 支持 Cloudflare Turnstile 用户校验。 21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式** 22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。 + [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
## 部署 ## 部署
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
@@ -160,18 +164,31 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456` 初始账号用户名为 `root`,密码为 `123456`
### 基于 Docker Compose 进行部署
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
```shell
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
docker-compose up -d
# 查看部署状态
docker-compose ps
```
### 手动部署 ### 手动部署
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell ```shell
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# 构建前端 # 构建前端
cd one-api/web cd one-api/web/default
npm install npm install
npm run build npm run build
# 构建后端 # 构建后端
cd .. cd ../..
go mod download go mod download
go build -ldflags "-s -w" -o one-api go build -ldflags "-s -w" -o one-api
```` ````
@@ -249,6 +266,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. 首先 fork 一份代码。 1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口 3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
@@ -330,32 +349,40 @@ graph LR
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置 4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL
5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`MEMORY_CACHE_ENABLED=true` + 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
+ 例子:`SYNC_FREQUENCY=60` + 例子:`SYNC_FREQUENCY=60`
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave` + 例子:`NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440` + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440` + 例子:`CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5` + 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true` + 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5` + 例子:`BATCH_UPDATE_INTERVAL=5`
13. 请求频率限制: 14. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置: 15. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
17. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
18. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -394,13 +421,16 @@ https://openai.justsong.cn
+ 检查你的接口地址和 API Key 有没有填对。 + 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。 + 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试` 6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游道 429 了。 + 上游道 429 了。
7. 升级之后我的数据会丢失吗? 7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。 + 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 + 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
8. 升级之前数据库需要做变更吗? 8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。 + 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。 + 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。
+ 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。
## 相关项目 ## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

29
common/blacklist/main.go Normal file
View File

@@ -0,0 +1,29 @@
package blacklist
import (
"fmt"
"sync"
)
var blackList sync.Map
func init() {
blackList = sync.Map{}
}
func userId2Key(id int) string {
return fmt.Sprintf("userid_%d", id)
}
func BanUser(id int) {
blackList.Store(userId2Key(id), true)
}
func UnbanUser(id int) {
blackList.Delete(userId2Key(id))
}
func IsUserBanned(id int) bool {
_, ok := blackList.Load(userId2Key(id))
return ok
}

140
common/config/config.go Normal file
View File

@@ -0,0 +1,140 @@
package config
import (
"github.com/songquanpeng/one-api/common/env"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var MessagePusherAddress = ""
var MessagePusherToken = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser int64 = 0
var QuotaForInviter int64 = 0
var QuotaForInvitee int64 = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold int64 = 1000
var PreConsumedQuota int64 = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
"air": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")

View File

@@ -1,105 +1,9 @@
package common package common
import ( import "time"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
@@ -108,37 +12,10 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const ( const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0 UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
) )
const ( const (
@@ -162,30 +39,40 @@ const (
) )
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = iota
ChannelTypeOpenAI = 1 ChannelTypeOpenAI
ChannelTypeAPI2D = 2 ChannelTypeAPI2D
ChannelTypeAzure = 3 ChannelTypeAzure
ChannelTypeCloseAI = 4 ChannelTypeCloseAI
ChannelTypeOpenAISB = 5 ChannelTypeOpenAISB
ChannelTypeOpenAIMax = 6 ChannelTypeOpenAIMax
ChannelTypeOhMyGPT = 7 ChannelTypeOhMyGPT
ChannelTypeCustom = 8 ChannelTypeCustom
ChannelTypeAILS = 9 ChannelTypeAILS
ChannelTypeAIProxy = 10 ChannelTypeAIProxy
ChannelTypePaLM = 11 ChannelTypePaLM
ChannelTypeAPI2GPT = 12 ChannelTypeAPI2GPT
ChannelTypeAIGC2D = 13 ChannelTypeAIGC2D
ChannelTypeAnthropic = 14 ChannelTypeAnthropic
ChannelTypeBaidu = 15 ChannelTypeBaidu
ChannelTypeZhipu = 16 ChannelTypeZhipu
ChannelTypeAli = 17 ChannelTypeAli
ChannelTypeXunfei = 18 ChannelTypeXunfei
ChannelType360 = 19 ChannelType360
ChannelTypeOpenRouter = 20 ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary = 21 ChannelTypeAIProxyLibrary
ChannelTypeFastGPT = 22 ChannelTypeFastGPT
ChannelTypeTencent = 23 ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeDummy
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
@@ -200,7 +87,7 @@ var ChannelBaseURLs = []string{
"", // 8 "", // 8
"https://api.caipacity.com", // 9 "https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10 "https://api.aiproxy.io", // 10
"", // 11 "https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12 "https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13 "https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14 "https://api.anthropic.com", // 14
@@ -213,4 +100,20 @@ var ChannelBaseURLs = []string{
"https://api.aiproxy.io", // 21 "https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22 "https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23 "https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
} }
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

View File

@@ -1,6 +1,12 @@
package common package common
import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db" var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,67 +0,0 @@
package common
import (
"crypto/tls"
"encoding/base64"
"fmt"
"net/smtp"
"strings"
)
func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
if err != nil {
return err
}
client, err := smtp.NewClient(conn, SMTPServer)
if err != nil {
return err
}
defer client.Close()
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Mail(SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
for _, receiver := range receiverEmails {
if err = client.Rcpt(receiver); err != nil {
return err
}
}
w, err := client.Data()
if err != nil {
return err
}
_, err = w.Write(mail)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
}
return err
}

View File

@@ -15,10 +15,7 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool { func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path) _, err := e.Open(path)
if err != nil { return err == nil
return false
}
return true
} }
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {

42
common/env/helper.go vendored Normal file
View File

@@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@@ -5,18 +5,37 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
func UnmarshalBodyReusable(c *gin.Context, v any) error { const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body) requestBody, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
return err return nil, err
} }
err = c.Request.Body.Close() _ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil { if err != nil {
return err return err
} }
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v) err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil { if err != nil {
return err return err
} }
@@ -24,3 +43,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil return nil
} }
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -1,6 +1,9 @@
package common package common
import "encoding/json" import (
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)
var GroupRatio = map[string]float64{ var GroupRatio = map[string]float64{
"default": 1, "default": 1,
@@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string { func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio) jsonBytes, err := json.Marshal(GroupRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) logger.SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }
@@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 { func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name] ratio, ok := GroupRatio[name]
if !ok { if !ok {
SysError("group ratio not found: " + name) logger.SysError("group ratio not found: " + name)
return 1 return 1
} }
return ratio return ratio

217
common/helper/helper.go Normal file
View File

@@ -0,0 +1,217 @@
package helper
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
return inter
case int:
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

111
common/image/image.go Normal file
View File

@@ -0,0 +1,111 @@
package image
import (
"bytes"
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
}
return img.Width, img.Height, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
// Check if the URL is a data URL
matches := dataURLPattern.FindStringSubmatch(url)
if len(matches) == 3 {
// URL is a data URL
mimeType = "image/" + matches[1]
data = matches[2]
return
}
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}
func GetImageSize(image string) (width int, height int, err error) {
if strings.HasPrefix(image, "data:image/") {
return GetImageSizeFromBase64(image)
}
return GetImageSizeFromUrl(image)
}

171
common/image/image_test.go Normal file
View File

@@ -0,0 +1,171 @@
package image_test
import (
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strconv"
"strings"
"testing"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
type CountingReader struct {
reader io.Reader
BytesRead int
}
func (r *CountingReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.BytesRead += n
return n, err
}
var (
cases = []struct {
url string
format string
width int
height int
}{
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
}
)
func TestDecode(t *testing.T) {
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
// webp: 99529
// gif: 956153
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 4096
// png: 4096
// webp: 4096
// gif: 4096
// jpeg#01: 4096
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestBase64(t *testing.T) {
// Bytes read:
// jpeg: 1063892
// png: 294462
// webp: 99072
// gif: 953856
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 1536
// png: 768
// webp: 768
// gif: 1536
// jpeg#01: 3840
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestGetImageSize(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -3,6 +3,8 @@ package common
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@@ -36,7 +38,11 @@ func init() {
} }
if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET") if os.Getenv("SESSION_SECRET") == "random_string" {
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
config.SessionSecret = os.Getenv("SESSION_SECRET")
}
} }
if os.Getenv("SQLITE_PATH") != "" { if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH") SQLitePath = os.Getenv("SQLITE_PATH")
@@ -53,5 +59,6 @@ func init() {
log.Fatal(err) log.Fatal(err)
} }
} }
logger.LogDir = *LogDir
} }
} }

View File

@@ -0,0 +1,7 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@@ -1,9 +1,11 @@
package common package logger
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"io" "io"
"log" "log"
"os" "os"
@@ -13,19 +15,17 @@ import (
) )
const ( const (
loggerDEBUG = "DEBUG"
loggerINFO = "INFO" loggerINFO = "INFO"
loggerWarn = "WARN" loggerWarn = "WARN"
loggerError = "ERR" loggerError = "ERR"
) )
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex var setupLogLock sync.Mutex
var setupLogWorking bool var setupLogWorking bool
func SetupLogger() { func SetupLogger() {
if *LogDir != "" { if LogDir != "" {
ok := setupLogLock.TryLock() ok := setupLogLock.TryLock()
if !ok { if !ok {
log.Println("setup log is already working") log.Println("setup log is already working")
@@ -35,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock() setupLogLock.Unlock()
setupLogWorking = false setupLogWorking = false
}() }()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")
@@ -55,29 +55,52 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func LogInfo(ctx context.Context, msg string) { func Debug(ctx context.Context, msg string) {
if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)
}
}
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg) logHelper(ctx, loggerINFO, msg)
} }
func LogWarn(ctx context.Context, msg string) { func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg) logHelper(ctx, loggerWarn, msg)
} }
func LogError(ctx context.Context, msg string) { func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg) logHelper(ctx, loggerError, msg)
} }
func Debugf(ctx context.Context, format string, a ...any) {
Debug(ctx, fmt.Sprintf(format, a...))
}
func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...))
}
func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a...))
}
func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level string, msg string) { func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter writer := gin.DefaultErrorWriter
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
id := ctx.Value(RequestIdKey) id := ctx.Value(RequestIdKey)
if id == nil {
id = helper.GenRequestID()
}
now := time.Now() now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here if !setupLogWorking {
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true setupLogWorking = true
go func() { go func() {
SetupLogger() SetupLogger()
@@ -90,11 +113,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1) os.Exit(1)
} }
func LogQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

90
common/message/email.go Normal file
View File

@@ -0,0 +1,90 @@
package message
import (
"crypto/rand"
"crypto/tls"
"encoding/base64"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
if receiver == "" {
return fmt.Errorf("receiver is empty")
}
if config.SMTPFrom == "" { // for compatibility
config.SMTPFrom = config.SMTPAccount
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";")
if config.SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: config.SMTPServer,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil {
return err
}
client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil {
return err
}
defer client.Close()
if err = client.Auth(auth); err != nil {
return err
}
if err = client.Mail(config.SMTPFrom); err != nil {
return err
}
receiverEmails := strings.Split(receiver, ";")
for _, receiver := range receiverEmails {
if err = client.Rcpt(receiver); err != nil {
return err
}
}
w, err := client.Data()
if err != nil {
return err
}
_, err = w.Write(mail)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
} else {
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
}
return err
}

22
common/message/main.go Normal file
View File

@@ -0,0 +1,22 @@
package message
import (
"fmt"
"github.com/songquanpeng/one-api/common/config"
)
const (
ByAll = "all"
ByEmail = "email"
ByMessagePusher = "message_pusher"
)
func Notify(by string, title string, description string, content string) error {
if by == ByEmail {
return SendEmail(title, config.RootUserEmail, content)
}
if by == ByMessagePusher {
return SendMessage(title, description, content)
}
return fmt.Errorf("unknown notify method: %s", by)
}

View File

@@ -0,0 +1,53 @@
package message
import (
"bytes"
"encoding/json"
"errors"
"github.com/songquanpeng/one-api/common/config"
"net/http"
)
type request struct {
Title string `json:"title"`
Description string `json:"description"`
Content string `json:"content"`
URL string `json:"url"`
Channel string `json:"channel"`
Token string `json:"token"`
}
type response struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func SendMessage(title string, description string, content string) error {
if config.MessagePusherAddress == "" {
return errors.New("message pusher address is not set")
}
req := request{
Title: title,
Description: description,
Content: content,
Token: config.MessagePusherToken,
}
data, err := json.Marshal(req)
if err != nil {
return err
}
resp, err := http.Post(config.MessagePusherAddress,
"application/json", bytes.NewBuffer(data))
if err != nil {
return err
}
var res response
err = json.NewDecoder(resp.Body).Decode(&res)
if err != nil {
return err
}
if !res.Success {
return errors.New(res.Message)
}
return nil
}

View File

@@ -2,29 +2,44 @@ package common
import ( import (
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/logger"
"strings" "strings"
) )
const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// ModelRatio // ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing // https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens // 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens // 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{ var ModelRatio = map[string]float64{
// https://openai.com/pricing
"gpt-4": 15, "gpt-4": 15,
"gpt-4-0314": 15, "gpt-4-0314": 15,
"gpt-4-0613": 15, "gpt-4-0613": 15,
"gpt-4-32k": 30, "gpt-4-32k": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30, "gpt-4-32k-0613": 30,
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,
"text-curie-001": 1, "text-curie-001": 1,
@@ -33,40 +48,140 @@ var ModelRatio = map[string]float64{
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,
"ada": 10, "ada": 10,
"text-embedding-ada-002": 0.05, "text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e": 8, "dall-e-2": 8, // $0.016 - $0.020 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens "dall-e-3": 20, // $0.040 - $0.120 / image
"claude-2": 5.51, // $11.02 / 1M tokens // https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8K": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1, "PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 10, // ¥0.14 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
// https://platform.baichuan-ai.com/price
"Baichuan2-Turbo": 0.008 * RMB,
"Baichuan2-Turbo-192k": 0.016 * RMB,
"Baichuan2-53B": 0.02 * RMB,
// https://api.minimax.chat/document/price
"abab6-chat": 0.1 * RMB,
"abab5.5-chat": 0.015 * RMB,
"abab5.5s-chat": 0.005 * RMB,
// https://docs.mistral.ai/platform/pricing/
"open-mistral-7b": 0.25 / 1000 * USD,
"open-mixtral-8x7b": 0.7 / 1000 * USD,
"mistral-small-latest": 2.0 / 1000 * USD,
"mistral-medium-latest": 2.7 / 1000 * USD,
"mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/
"llama2-70b-4096": 0.7 / 1000 * USD,
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB,
}
var CompletionRatio = map[string]float64{}
var DefaultModelRatio map[string]float64
var DefaultCompletionRatio map[string]float64
func init() {
DefaultModelRatio = make(map[string]float64)
for k, v := range ModelRatio {
DefaultModelRatio[k] = v
}
DefaultCompletionRatio = make(map[string]float64)
for k, v := range CompletionRatio {
DefaultCompletionRatio[k] = v
}
}
func AddNewMissingRatio(oldRatio string) string {
newRatio := make(map[string]float64)
err := json.Unmarshal([]byte(oldRatio), &newRatio)
if err != nil {
logger.SysError("error unmarshalling old ratio: " + err.Error())
return oldRatio
}
for k, v := range DefaultModelRatio {
if _, ok := newRatio[k]; !ok {
newRatio[k] = v
}
}
jsonBytes, err := json.Marshal(newRatio)
if err != nil {
logger.SysError("error marshalling new ratio: " + err.Error())
return oldRatio
}
return string(jsonBytes)
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio) jsonBytes, err := json.Marshal(ModelRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) logger.SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }
@@ -77,26 +192,72 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name] ratio, ok := ModelRatio[name]
if !ok { if !ok {
SysError("model ratio not found: " + name) ratio, ok = DefaultModelRatio[name]
}
if !ok {
logger.SysError("model ratio not found: " + name)
return 30 return 30
} }
return ratio return ratio
} }
func GetCompletionRatio(name string) float64 { func CompletionRatio2JSONString() string {
if strings.HasPrefix(name, "gpt-3.5") { jsonBytes, err := json.Marshal(CompletionRatio)
return 1.333333 if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
} }
if strings.HasPrefix(name, "gpt-4") { return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if ratio, ok := DefaultCompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2 return 2
} }
if strings.HasPrefix(name, "claude-instant-1") { return 4.0 / 3.0
return 3.38
} }
if strings.HasPrefix(name, "claude-2") { if strings.HasPrefix(name, "gpt-4") {
return 2.965517 if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-") {
return 3
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
} }
return 1 return 1
} }

8
common/random.go Normal file
View File

@@ -0,0 +1,8 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

View File

@@ -3,6 +3,7 @@ package common
import ( import (
"context" "context"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger"
"os" "os"
"time" "time"
) )
@@ -14,18 +15,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) { func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" { if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false RedisEnabled = false
SysLog("REDIS_CONN_STRING not set, Redis is not enabled") logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil return nil
} }
if os.Getenv("SYNC_FREQUENCY") == "" { if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled") logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil return nil
} }
SysLog("Redis is enabled") logger.SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil { if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error()) logger.FatalLog("failed to parse Redis connection string: " + err.Error())
} }
RDB = redis.NewClient(opt) RDB = redis.NewClient(opt)
@@ -34,7 +35,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result() _, err = RDB.Ping(ctx).Result()
if err != nil { if err != nil {
FatalLog("Redis ping test failed: " + err.Error()) logger.FatalLog("Redis ping test failed: " + err.Error())
} }
return err return err
} }
@@ -42,7 +43,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options { func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil { if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error()) logger.FatalLog("failed to parse Redis connection string: " + err.Error())
} }
return opt return opt
} }

View File

@@ -2,208 +2,13 @@ package common
import ( import (
"fmt" "fmt"
"github.com/google/uuid" "github.com/songquanpeng/one-api/common/config"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
) )
func OpenBrowser(url string) { func LogQuota(quota int64) string {
var err error if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else { } else {
numStr = fmt.Sprintf("%d", num) return fmt.Sprintf("%d 点额度", quota)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
} }
} }
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -2,17 +2,18 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "github.com/songquanpeng/one-api/common/config"
"one-api/model" "github.com/songquanpeng/one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func GetSubscription(c *gin.Context) { func GetSubscription(c *gin.Context) {
var remainQuota int var remainQuota int64
var usedQuota int var usedQuota int64
var err error var err error
var token *model.Token var token *model.Token
var expiredTime int64 var expiredTime int64
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime expiredTime = token.ExpiredTime
@@ -21,25 +22,27 @@ func GetSubscription(c *gin.Context) {
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId)
if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId) usedQuota, err = model.GetUserUsedQuota(userId)
} }
}
if expiredTime <= 0 { if expiredTime <= 0 {
expiredTime = 0 expiredTime = 0
} }
if err != nil { if err != nil {
openAIError := OpenAIError{ Error := relaymodel.Error{
Message: err.Error(), Message: err.Error(),
Type: "upstream_error", Type: "upstream_error",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": openAIError, "error": Error,
}) })
return return
} }
quota := remainQuota + usedQuota quota := remainQuota + usedQuota
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
if token != nil && token.UnlimitedQuota { if token != nil && token.UnlimitedQuota {
amount = 100000000 amount = 100000000
@@ -57,10 +60,10 @@ func GetSubscription(c *gin.Context) {
} }
func GetUsage(c *gin.Context) { func GetUsage(c *gin.Context) {
var quota int var quota int64
var err error var err error
var token *model.Token var token *model.Token
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota quota = token.UsedQuota
@@ -69,18 +72,18 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId) quota, err = model.GetUserUsedQuota(userId)
} }
if err != nil { if err != nil {
openAIError := OpenAIError{ Error := relaymodel.Error{
Message: err.Error(), Message: err.Error(),
Type: "one_api_error", Type: "one_api_error",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": openAIError, "error": Error,
}) })
return return
} }
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
usage := OpenAIUsageResponse{ usage := OpenAIUsageResponse{
Object: "list", Object: "list",

View File

@@ -4,10 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
@@ -92,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers { for k := range headers {
req.Header.Add(k, headers.Get(k)) req.Header.Add(k, headers.Get(k))
} }
res, err := httpClient.Do(req) res, err := util.HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -292,7 +296,7 @@ func UpdateChannelBalance(c *gin.Context) {
} }
func updateAllChannelsBalance() error { func updateAllChannelsBalance() error {
channels, err := model.GetAllChannels(0, 0, true) channels, err := model.GetAllChannels(0, 0, "all")
if err != nil { if err != nil {
return err return err
} }
@@ -310,24 +314,23 @@ func updateAllChannelsBalance() error {
} else { } else {
// err is nil & balance <= 0 means quota is used up // err is nil & balance <= 0 means quota is used up
if balance <= 0 { if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足") monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
} }
} }
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
} }
return nil return nil
} }
func UpdateAllChannelsBalance(c *gin.Context) { func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async //err := updateAllChannelsBalance()
err := updateAllChannelsBalance() //if err != nil {
if err != nil { // c.JSON(http.StatusOK, gin.H{
c.JSON(http.StatusOK, gin.H{ // "success": false,
"success": false, // "message": err.Error(),
"message": err.Error(), // })
}) // return
return //}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) { func AutomaticallyUpdateChannels(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels") logger.SysLog("updating all channels")
_ = updateAllChannelsBalance() _ = updateAllChannelsBalance()
common.SysLog("channels update done") logger.SysLog("channels update done")
} }
} }

View File

@@ -5,88 +5,36 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http" "net/http"
"one-api/common" "net/http/httptest"
"one-api/model" "net/url"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
switch channel.Type { testRequest := &relaymodel.GeneralOpenAIRequest{
case common.ChannelTypePaLM: MaxTokens: 2,
fallthrough Stream: false,
case common.ChannelTypeAnthropic: Model: "gpt-3.5-turbo",
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
} }
}() testMessage := relaymodel.Message{
default:
request.Model = "gpt-3.5-turbo"
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else {
if channel.GetBaseURL() != "" {
requestURL = channel.GetBaseURL()
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
return nil, nil
}
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
}
testMessage := Message{
Role: "user", Role: "user",
Content: "hi", Content: "hi",
} }
@@ -94,6 +42,72 @@ func buildTestRequest() *ChatRequest {
return testRequest return testRequest
} }
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
@@ -111,9 +125,8 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
testRequest := buildTestRequest()
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel, *testRequest) err, _ = testChannel(channel)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
@@ -137,23 +150,9 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
// disable & notify func testChannels(notify bool, scope string) error {
func disableChannel(channelId int, channelName string, reason string) { if config.RootUserEmail == "" {
if common.RootUserEmail == "" { config.RootUserEmail = model.GetRootUserEmail()
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
if testAllChannelsRunning { if testAllChannelsRunning {
@@ -162,49 +161,57 @@ func testAllChannels(notify bool) error {
} }
testAllChannelsRunning = true testAllChannelsRunning = true
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true) channels, err := model.GetAllChannels(0, 0, scope)
if err != nil { if err != nil {
return err return err
} }
testRequest := buildTestRequest() var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
go func() { go func() {
for _, channel := range channels { for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled { isChannelEnabled := channel.Status == common.ChannelStatusEnabled
continue
}
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest) err, openaiErr := testChannel(channel)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error()) if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
} }
if shouldDisableChannel(openaiErr, -1) { }
disableChannel(channel.Id, channel.Name, err.Error()) if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify { if notify {
err := common.SendEmail("道测试完成", common.RootUserEmail, "道测试完成,如果没有收到禁用通知,说明所有道都正常") err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
} }
}() }()
return nil return nil
} }
func TestAllChannels(c *gin.Context) { func TestChannels(c *gin.Context) {
err := testAllChannels(true) scope := c.Query("scope")
if scope == "" {
scope = "all"
}
err := testChannels(true, scope)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -222,8 +229,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) { func AutomaticallyTestChannels(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels") logger.SysLog("testing all channels")
_ = testAllChannels(false) _ = testChannels(false, "all")
common.SysLog("channel test finished") logger.SysLog("channel test finished")
} }
} }

View File

@@ -2,9 +2,10 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
) )
@@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited")
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
}) })
return return
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {

View File

@@ -7,9 +7,12 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
) )
@@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" { if code == "" {
return nil, errors.New("无效的参数") return nil, errors.New("无效的参数")
} }
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values) jsonData, err := json.Marshal(values)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
} }
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
common.SysLog(err.Error()) logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
} }
defer res.Body.Close() defer res.Body.Close()
@@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req) res2, err := client.Do(req)
if err != nil { if err != nil {
common.SysLog(err.Error()) logger.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
} }
defer res2.Body.Close() defer res2.Body.Close()
@@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
} }
func GitHubBind(c *gin.Context) { func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
state := common.GetRandomString(12) state := helper.GetRandomString(12)
session.Set("oauth_state", state) session.Set("oauth_state", state)
err := session.Save() err := session.Save()
if err != nil { if err != nil {

View File

@@ -2,13 +2,13 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"net/http" "net/http"
"one-api/common"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio { for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -2,9 +2,9 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
) )
@@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -3,9 +3,11 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -18,55 +20,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{ "data": gin.H{
"version": common.Version, "version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": config.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": config.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": config.GitHubClientId,
"system_name": common.SystemName, "system_name": config.SystemName,
"logo": common.Logo, "logo": config.Logo,
"footer_html": common.Footer, "footer_html": config.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": config.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": config.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": config.TopUpLink,
"chat_link": common.ChatLink, "chat_link": config.ChatLink,
"quota_per_unit": common.QuotaPerUnit, "quota_per_unit": config.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled, "display_in_currency": config.DisplayInCurrencyEnabled,
}, },
}) })
return return
} }
func GetNotice(c *gin.Context) { func GetNotice(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["Notice"], "data": config.OptionMap["Notice"],
}) })
return return
} }
func GetAbout(c *gin.Context) { func GetAbout(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["About"], "data": config.OptionMap["About"],
}) })
return return
} }
func GetHomePageContent(c *gin.Context) { func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["HomePageContent"], "data": config.OptionMap["HomePageContent"],
}) })
return return
} }
@@ -80,9 +82,9 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
if common.EmailDomainRestrictionEnabled { if config.EmailDomainRestrictionEnabled {
allowed := false allowed := false
for _, domain := range common.EmailDomainWhitelist { for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) { if strings.HasSuffix(email, "@"+domain) {
allowed = true allowed = true
break break
@@ -105,11 +107,11 @@ func SendEmailVerification(c *gin.Context) {
} }
code := common.GenerateVerificationCode(6) code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
"<p>您的验证码为: <strong>%s</strong></p>"+ "<p>您的验证码为: <strong>%s</strong></p>"+
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) "<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := message.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -142,13 +144,13 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ "<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) "<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := message.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -2,8 +2,14 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
) )
// https://platform.openai.com/docs/api-reference/models/list // https://platform.openai.com/docs/api-reference/models/list
@@ -35,6 +41,7 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() { func init() {
var permission []OpenAIModelPermission var permission []OpenAIModelPermission
@@ -53,399 +60,63 @@ func init() {
IsBlocking: false, IsBlocking: false,
}) })
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{ for i := 0; i < constant.APITypeDummy; i++ {
{ if i == constant.APITypeAIProxyLibrary {
Id: "dall-e", continue
}
adaptor := helper.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model", Object: "model",
Created: 1677649963, Created: 1626777600,
OwnedBy: "openai", OwnedBy: channelName,
Permission: permission, Permission: permission,
Root: "dall-e", Root: modelName,
Parent: nil, Parent: nil,
}, })
{ }
Id: "whisper-1", }
for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model", Object: "model",
Created: 1677649963, Created: 1626777600,
OwnedBy: "openai", OwnedBy: channelName,
Permission: permission, Permission: permission,
Root: "whisper-1", Root: modelName,
Parent: nil, Parent: nil,
}, })
{ }
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Parent: nil,
},
} }
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {
openAIModelsMap[model.Id] = model openAIModelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
meta := &util.RelayMeta{
ChannelType: i,
}
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
}
func DashboardListModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelId2Models,
})
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
@@ -460,14 +131,14 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model) c.JSON(200, model)
} else { } else {
openAIError := OpenAIError{ Error := relaymodel.Error{
Message: fmt.Sprintf("The model '%s' does not exist", modelId), Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error", Type: "invalid_request_error",
Param: "model", Param: "model",
Code: "model_not_found", Code: "model_not_found",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": openAIError, "error": Error,
}) })
} }
} }

View File

@@ -2,9 +2,10 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) { func GetOptions(c *gin.Context) {
var options []*model.Option var options []*model.Option
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap { for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue continue
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
Key: k, Key: k,
Value: common.Interface2String(v), Value: helper.Interface2String(v),
}) })
} }
common.OptionMapRWMutex.Unlock() config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -42,8 +43,16 @@ func UpdateOption(c *gin.Context) {
return return
} }
switch option.Key { switch option.Key {
case "Theme":
if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题",
})
return
}
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" { if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret", "message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
@@ -51,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -59,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" { if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!", "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -67,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" { if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@@ -2,9 +2,10 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
) )
@@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
} }
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := common.GetUUID() key := helper.GetUUID()
cleanRedemption := model.Redemption{ cleanRedemption := model.Redemption{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: redemption.Name, Name: redemption.Name,
Key: key, Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota, Quota: redemption.Quota,
} }
err = cleanRedemption.Insert() err = cleanRedemption.Insert()

View File

@@ -1,329 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,151 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
)
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var audioResponse AudioResponse
defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,220 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,177 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,287 +0,0 @@
package controller
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.Content,
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

View File

@@ -1,646 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
if consumeQuota {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

View File

@@ -1,188 +0,0 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
var stopFinishReason = "stop"
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
}
common.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return defaultTokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if common.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
func countTokenInput(input any, model string) int {
switch input.(type) {
case string:
return countTokenText(input.(string), model)
case []string:
text := ""
for _, s := range input.([]string) {
text += s
}
return countTokenText(text, model)
}
return 0
}
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
return false
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var textResponse TextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}

View File

@@ -1,306 +0,0 @@
package controller
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io"
"net/http"
"net/url"
"one-api/common"
"strings"
"time"
)
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.Content,
})
}
}
xunfeiRequest := XunfeiChatRequest{}
xunfeiRequest.Header.AppId = xunfeiAppId
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: stopFinishReason,
}
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &stopFinishReason
}
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
HmacWithShaToBase64 := func(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
mac.Write([]byte(data))
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
ul, err := url.Parse(hostUrl)
if err != nil {
fmt.Println(err)
}
date := time.Now().UTC().Format(time.RFC1123)
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
sign := strings.Join(signString, "\n")
sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
"hmac-sha256", "host date request-line", sha)
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
v := url.Values{}
v.Add("host", ul.Host)
v.Add("date", date)
v.Add("authorization", authorization)
callUrl := hostUrl + "?" + v.Encode()
return callUrl
}
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
setEventStreamHeaders(c)
var usage Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
var usage Usage
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
return nil, nil, err
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
for {
_, msg, err := conn.ReadMessage()
if err != nil {
common.SysError("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
common.SysError("error closing websocket connection: " + err.Error())
}
break
}
}
stopChan <- true
}()
return dataChan, stopChan, nil
}
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
}
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}

View File

@@ -1,231 +1,132 @@
package controller package controller
import ( import (
"bytes"
"context"
"fmt" "fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
type Message struct { "github.com/songquanpeng/one-api/common/helper"
Role string `json:"role"` "github.com/songquanpeng/one-api/common/logger"
Content string `json:"content"` "github.com/songquanpeng/one-api/middleware"
Name *string `json:"name,omitempty"` dbmodel "github.com/songquanpeng/one-api/model"
} "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
const ( "github.com/songquanpeng/one-api/relay/controller"
RelayModeUnknown = iota "github.com/songquanpeng/one-api/relay/model"
RelayModeChatCompletions "github.com/songquanpeng/one-api/relay/util"
RelayModeCompletions "io"
RelayModeEmbeddings "net/http"
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudio
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct { func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
Model string `json:"model,omitempty"` var err *model.ErrorWithStatusCode
Messages []Message `json:"messages,omitempty"` switch relayMode {
Prompt any `json:"prompt,omitempty"` case constant.RelayModeImagesGenerations:
Stream bool `json:"stream,omitempty"` err = controller.RelayImageHelper(c, relayMode)
MaxTokens int `json:"max_tokens,omitempty"` case constant.RelayModeAudioSpeech:
Temperature float64 `json:"temperature,omitempty"` fallthrough
TopP float64 `json:"top_p,omitempty"` case constant.RelayModeAudioTranslation:
N int `json:"n,omitempty"` fallthrough
Input any `json:"input,omitempty"` case constant.RelayModeAudioTranscription:
Instruction string `json:"instruction,omitempty"` err = controller.RelayAudioHelper(c, relayMode)
Size string `json:"size,omitempty"` default:
Functions any `json:"functions,omitempty"` err = controller.RelayTextHelper(c)
} }
return err
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type AudioResponse struct {
Text string `json:"text,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
} }
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := RelayModeUnknown ctx := c.Request.Context()
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { relayMode := constant.Path2RelayMode(c.Request.URL.Path)
relayMode = RelayModeChatCompletions if config.DebugEnabled {
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { requestBody, _ := common.GetRequestBody(c)
relayMode = RelayModeCompletions logger.Debugf(ctx, "request body: %s", string(requestBody))
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
} }
var err *OpenAIErrorWithStatusCode channelId := c.GetInt("channel_id")
switch relayMode { bizErr := relay(c, relayMode)
case RelayModeImagesGenerations: if bizErr == nil {
err = relayImageHelper(c, relayMode) monitor.Emit(channelId, true)
case RelayModeAudio: return
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
} }
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
retryTimes = 0
}
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil { if err != nil {
requestId := c.GetString(common.RequestIdKey) logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
retryTimesStr := c.Query("retry") break
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
} }
if retryTimes > 0 { logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) if channel.Id == lastFailedChannelId {
} else { continue
if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
} }
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) middleware.SetupContextForSelectedChannel(c, channel, originalModel)
c.JSON(err.StatusCode, gin.H{ requestBody, err := common.GetRequestBody(c)
"error": err.OpenAIError, c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
}) })
} }
channelId := c.GetInt("channel_id")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
} }
func shouldRetry(c *gin.Context, statusCode int) bool {
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if statusCode == http.StatusTooManyRequests {
return true
}
if statusCode/100 == 5 {
return true
}
if statusCode == http.StatusBadRequest {
return false
}
if statusCode/100 == 2 {
return false
}
return true
}
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
} }
} }
func RelayNotImplemented(c *gin.Context) { func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{ err := model.Error{
Message: "API not implemented", Message: "API not implemented",
Type: "one_api_error", Type: "one_api_error",
Param: "", Param: "",
@@ -237,7 +138,7 @@ func RelayNotImplemented(c *gin.Context) {
} }
func RelayNotFound(c *gin.Context) { func RelayNotFound(c *gin.Context) {
err := OpenAIError{ err := model.Error{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error", Type: "invalid_request_error",
Param: "", Param: "",

View File

@@ -2,9 +2,11 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
) )
@@ -14,7 +16,10 @@ func GetAllTokens(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
order := c.Query("order")
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -119,9 +124,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: common.GenerateKey(), Key: helper.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota, RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
@@ -137,6 +142,7 @@ func AddToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": cleanToken,
}) })
return return
} }
@@ -187,7 +193,7 @@ func UpdateToken(c *gin.Context) {
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@@ -3,10 +3,13 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -18,7 +21,7 @@ type LoginRequest struct {
} }
func Login(c *gin.Context) { func Login(c *gin.Context) {
if !common.PasswordLoginEnabled { if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录", "message": "管理员关闭了密码登录",
"success": false, "success": false,
@@ -105,14 +108,14 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
if !common.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
"success": false, "success": false,
}) })
return return
} }
if !common.PasswordRegisterEnabled { if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false, "success": false,
@@ -135,7 +138,7 @@ func Register(c *gin.Context) {
}) })
return return
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" { if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -159,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username, DisplayName: user.Username,
InviterId: inviterId, InviterId: inviterId,
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
@@ -181,7 +184,10 @@ func GetAllUsers(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -189,12 +195,12 @@ func GetAllUsers(c *gin.Context) {
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": users, "data": users,
}) })
return
} }
func SearchUsers(c *gin.Context) { func SearchUsers(c *gin.Context) {
@@ -248,6 +254,29 @@ func GetUser(c *gin.Context) {
return return
} }
func GetUserDashboard(c *gin.Context) {
id := c.GetInt("id")
now := time.Now()
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法获取统计信息",
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dashboards,
})
return
}
func GenerateAccessToken(c *gin.Context) { func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, true) user, err := model.GetUserById(id, true)
@@ -258,7 +287,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = common.GetUUID() user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -295,7 +324,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = common.GetRandomString(4) user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -702,7 +731,7 @@ func EmailBind(c *gin.Context) {
return return
} }
if user.Role == common.RoleRootUser { if user.Role == common.RoleRootUser {
common.RootUserEmail = email config.RootUserEmail = email
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,

View File

@@ -5,9 +5,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
) )
@@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" { if code == "" {
return "", errors.New("无效的参数") return "", errors.New("无效的参数")
} }
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header.Set("Authorization", common.WeChatServerToken) req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{ client := http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,
@@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User" user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
@@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
} }
func WeChatBind(c *gin.Context) { func WeChatBind(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,

View File

@@ -2,26 +2,26 @@ version: '3.4'
services: services:
one-api: one-api:
image: justsong/one-api:latest image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
container_name: one-api container_name: one-api
restart: always restart: always
command: --log-dir /app/logs command: --log-dir /app/logs
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data:/data - ./data/oneapi:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - NODE_TYPE=slave # 多机部署时从节点取消注释该行
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
depends_on: depends_on:
- redis - redis
- db
healthcheck: healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s interval: 30s
@@ -29,6 +29,21 @@ services:
retries: 3 retries: 3
redis: redis:
image: redis:latest image: "${REGISTRY:-docker.io}/redis:latest"
container_name: redis container_name: redis
restart: always restart: always
db:
image: "${REGISTRY:-docker.io}/mysql:8.2.0"
restart: always
container_name: mysql
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
ports:
- '3306:3306'
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: oneapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: one-api # 自动创建数据库

18
go.mod
View File

@@ -1,4 +1,4 @@
module one-api module github.com/songquanpeng/one-api
// +heroku goVersion go1.18 // +heroku goVersion go1.18
go 1.18 go 1.18
@@ -15,7 +15,9 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5 github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.14.0 github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlite v1.4.3
@@ -26,6 +28,7 @@ require (
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@@ -39,7 +42,8 @@ require (
github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
@@ -50,12 +54,14 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.17.0 // indirect golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/sys v0.15.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

26
go.sum
View File

@@ -73,8 +73,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -150,11 +152,15 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -162,21 +168,21 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -8,12 +8,12 @@
"确认删除": "Confirm Delete", "确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding", "确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"", "\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s", "道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running", "测试已在运行中": "Test is already running",
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
"道测试完成": "Channel test completed", "道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", "道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
@@ -86,6 +86,7 @@
"该令牌已过期": "The token has expired", "该令牌已过期": "The token has expired",
"该令牌额度已用尽": "The token quota has been used up", "该令牌额度已用尽": "The token quota has been used up",
"无效的令牌": "Invalid token", "无效的令牌": "Invalid token",
"令牌验证失败": "Token verification failed",
"id 或 userId 为空!": "id or userId is empty!", "id 或 userId 为空!": "id or userId is empty!",
"quota 不能为负数!": "quota cannot be negative!", "quota 不能为负数!": "quota cannot be negative!",
"令牌额度不足": "Insufficient token quota", "令牌额度不足": "Insufficient token quota",
@@ -118,10 +119,11 @@
" 个月 ": " M ", " 个月 ": " M ",
" 年 ": " y ", " 年 ": " y ",
"未测试": "Not tested", "未测试": "Not tested",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", "道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", "已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", "已成功开始测试所有已启用渠道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", "渠道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用渠道余额!": "The balance of all enabled channels has been updated!",
"搜索渠道的 ID名称和密钥 ...": "Search for channel ID, name and key ...", "搜索渠道的 ID名称和密钥 ...": "Search for channel ID, name and key ...",
"名称": "Name", "名称": "Name",
"分组": "Group", "分组": "Group",
@@ -139,8 +141,9 @@
"启用": "Enable", "启用": "Enable",
"编辑": "Edit", "编辑": "Edit",
"添加新的渠道": "Add a new channel", "添加新的渠道": "Add a new channel",
"测试所有已启用通道": "Test all enabled channels", "测试所有道": "Test all channels",
"更新所有已启用通道余额": "Update the balance of all enabled channels", "测试所有已启用渠道": "Test all enabled channels",
"更新所有已启用渠道余额": "Update the balance of all enabled channels",
"刷新": "Refresh", "刷新": "Refresh",
"处理中...": "Processing...", "处理中...": "Processing...",
"绑定成功!": "Binding succeeded!", "绑定成功!": "Binding succeeded!",
@@ -204,11 +207,11 @@
"监控设置": "Monitoring Settings", "监控设置": "Monitoring Settings",
"最长响应时间": "Longest Response Time", "最长响应时间": "Longest Response Time",
"单位秒": "Unit in seconds", "单位秒": "Unit in seconds",
"当运行道全部测试时": "When all operating channels are tested", "当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded", "超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"额度提醒阈值": "Quota reminder threshold", "额度提醒阈值": "Quota reminder threshold",
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
"失败时自动禁用道": "Automatically disable the channel when it fails", "失败时自动禁用道": "Automatically disable the channel when it fails",
"保存监控设置": "Save Monitoring Settings", "保存监控设置": "Save Monitoring Settings",
"额度设置": "Quota Settings", "额度设置": "Quota Settings",
"新用户初始额度": "Initial quota for new users", "新用户初始额度": "Initial quota for new users",
@@ -402,7 +405,7 @@
"镜像": "Mirror", "镜像": "Mirror",
"请输入镜像站地址格式为https://domain.com可不填不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", "请输入镜像站地址格式为https://domain.com可不填不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
"模型": "Model", "模型": "Model",
"请选择该道所支持的模型": "Please select the model supported by the channel", "请选择该道所支持的模型": "Please select the model supported by the channel",
"填入基础模型": "Fill in the basic model", "填入基础模型": "Fill in the basic model",
"填入所有模型": "Fill in all models", "填入所有模型": "Fill in all models",
"清除所有模型": "Clear all models", "清除所有模型": "Clear all models",
@@ -453,9 +456,11 @@
"已绑定的邮箱账户": "Email Account Bound", "已绑定的邮箱账户": "Email Account Bound",
"用户信息更新成功!": "User information updated successfully!", "用户信息更新成功!": "User information updated successfully!",
"模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f",
"模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f": "model rate %.2f, group rate %.2f, completion rate %.2f",
"使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", "使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
"用户名称": "User Name", "用户名称": "User Name",
"令牌名称": "Token Name", "令牌名称": "Token Name",
"默认令牌": "Default Token",
"留空则查询全部用户": "Leave blank to query all users", "留空则查询全部用户": "Leave blank to query all users",
"留空则查询全部令牌": "Leave blank to query all tokens", "留空则查询全部令牌": "Leave blank to query all tokens",
"模型名称": "Model Name", "模型名称": "Model Name",
@@ -510,7 +515,7 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL", "Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL", "Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为道命名": "Please name the channel", "请为道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection", "模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
@@ -524,5 +529,250 @@
"模型版本": "Model version", "模型版本": "Model version",
"请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", "请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view", "点击查看": "click to view",
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!" "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!",
"处理中...": "Processing...",
"绑定成功!": "Binding successful!",
"登录成功!": "Login successful!",
"操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...",
"出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...",
"首页": "Home",
"渠道": "Channel",
"令牌": "API Keys",
"兑换": "Redeem",
"充值": "Recharge",
"用户": "Users",
"日志": "Logs",
"设置": "Settings",
"关于": "About",
"聊天": "Chat",
"注销成功!": "Logout successful!",
"注销": "Log out",
"登录": "Log in",
"注册": "Sign up",
"加载{name}中...": "Loading {name}...",
"未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!",
"请立刻修改默认密码!": "Please change the default password immediately!",
"欢迎回来": "Welcome back",
"没有账户?": "No account?",
"立刻注册": "Sign up now",
"用户名": "Username",
"密码": "Password",
"正在登录……": "Logging in...",
"忘记密码": "Forgot password",
"其他方式": "Other methods",
"微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)",
"验证码": "Verification code",
"全部用户": "All users",
"当前用户": "Current user",
"全部": "All",
"消费": "Consumption",
"管理": "Management",
"系统": "System",
"未知": "Unknown",
"其他模型": "Other models",
"复制成功": "Copy successful",
"使用明细": "Usages",
"刷新": "Refresh",
"收起面板": "Collapse panel",
"展开面板": "Expand panel",
"显示查询选项": "Show search options",
"隐藏查询选项": "Hide search options",
"用户名称": "User name",
"可选值": "Optional values",
"渠道 ID": "Channel ID",
"令牌名称": "Key name",
"模型名称": "Model name",
"起始时间": "Start time",
"结束时间": "End time",
"查询": "Query",
"隐藏条形图": "Hide bar chart",
"显示条形图": "Show bar chart",
"折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data",
"总消耗": "Total consumption",
"总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made",
"{model.name}: {model.value} 次": "{model.name}: {model.value} times",
"总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made",
"总消耗额度": "Total consumption limit",
"暂无数据": "No data available",
"更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!",
"复制用户名": "Copy username",
"{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}",
"共 0 条数据": "A total of 0 pieces of data",
"选择明细分类": "Select detail category",
"模型倍率": "model rate",
"分组倍率": "group rate",
"新密码已复制到剪贴板:": "New password has been copied to the clipboard:",
"密码重置确认": "Password reset confirmation",
"邮箱地址": "Email address",
"新密码": "New password",
"密码已复制到剪贴板:": "Password has been copied to the clipboard:",
"密码重置完成": "Password reset complete",
"提交": "Submit",
"返回登录": "Return to login",
"请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed",
"重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!",
"密码重置": "Password reset",
"重试": "Retry",
"组": "Group",
"令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard",
"邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard",
"系统令牌已复制到剪切板": "System token has been copied to the clipboard",
"请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
"账户已删除!": "Account has been deleted!",
"微信账户绑定成功!": "WeChat account binding successful!",
"请稍后几秒重试Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!",
"验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!",
"邮箱账户绑定成功!": "Email account binding successful!",
"个人信息": "Personal information",
"编辑个人信息": "Edit personal information",
"生成系统访问令牌": "Generate system access token",
"复制邀请链接": "Copy invitation link",
"删除个人帐户": "Delete personal account",
"普通用户": "Regular user",
"管理员": "Administrator",
"超级管理员": "Super administrator",
"显示名称": "Display name",
"GitHub 账号": "GitHub account",
"微信账号": "WeChat account",
"修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.",
"可用模型": "Available models",
"账号绑定": "Account binding",
"绑定微信": "Bind WeChat",
"绑定 GitHub": "Bind GitHub",
"绑定邮箱": "Bind Email",
"绑定": "Bind",
"绑定邮箱地址": "Bind email address",
"输入邮箱地址": "Enter email address",
"重新发送": "Resend",
"获取验证码": "Get verification code",
"确认绑定": "Confirm binding",
"取消": "Cancel",
"危险操作": "Dangerous operation",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered",
"输入你的账户名": "Enter your account name",
"以确认删除": "To confirm deletion",
"确认删除": "Confirm deletion",
"未使用": "Not used",
"已禁用": "Disabled",
"已使用": "Used",
"未知状态": "Unknown status",
"操作成功完成!": "Operation successfully completed!",
"搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...",
"名称": "Name",
"状态": "Status",
"额度": "Quota",
"创建时间": "Creation time",
"兑换时间": "Redemption time",
"操作": "Operation",
"尚未兑换": "Not yet redeemed",
"已复制到剪贴板!": "Copied to clipboard!",
"无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.",
"复制": "Copy",
"删除": "Delete",
"禁用": "Disable",
"启用": "Enable",
"编辑": "Edit",
"添加新的兑换码": "Add new redemption code",
"密码长度不得小于 8 位!": "Password length must not be less than 8 characters!",
"两次输入的密码不一致": "The two passwords entered do not match",
"注册成功!": "Registration successful!",
"请填写注册邮箱!": "Please fill in the registration email!",
"请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds",
"验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!",
"已有账户?": "Already have an account?",
"请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)",
"请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)",
"请再次输入密码": "Please enter the password again",
"请输入邮箱地址": "Please enter an email address",
"秒后可重发": "Can be resent after seconds",
"请输入邮箱验证码": "Please enter the email verification code",
"已过期": "Expired",
"已启用": "Enabled",
"已耗尽": "Exhausted",
"无": "None",
"令牌密钥": "API Key",
"令牌状态": "Key status",
"已用额度": "Used quota",
"剩余额度": "Remaining quota",
"过期时间": "Expiration time",
"你确定要删除这个令牌吗?": "Are you sure you want to delete this key?",
"无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.",
"无限制": "Unlimited",
"永不过期": "Never expires",
"使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.",
"API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.",
"创建令牌": "Create Key",
"什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!",
"你确定要删除该令牌吗": "Are you sure you want to delete this key",
"导出令牌信息": "Export key information",
"错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!",
"错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!",
"错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!",
"本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!",
"错误:": "Error:",
"加载首页内容失败...": "Failed to load homepage content...",
"系统状况": "System status",
"系统信息": "System information",
"系统信息总览": "System information overview",
"名称:": "Name:",
"版本:": "Version:",
"源码:": "Source code:",
"启动时间:": "Startup time:",
"系统配置": "System configuration",
"系统配置总览": "System configuration overview",
"邮箱验证:": "Email verification:",
"未启用": "Not enabled",
"Turnstile 用户校验:": "Turnstile user verification:",
"页面不存在": "Page does not exist",
"请检查你的浏览器地址是否正确": "Please check if your browser address is correct",
"个人设置": "Personal settings",
"运营设置": "Operations settings",
"系统设置": "System settings",
"其他设置": "Other settings",
"默认令牌": "Default key",
"过期时间必须在当前时间之后!": "Expiration time must be after the current time!",
"额度必须大于等于 0": "Quota must be greater than or equal to 0!",
"过期时间格式错误!": "Expiration time format error!",
"创建令牌数量必须大于等于 1": "The number of keys to create must be greater than or equal to 1!",
"令牌修改成功": "API Key modification successful",
"令牌创建成功": "API Key creation successful",
"更新令牌信息": "Update key information",
"创建新的令牌": "Create a new key",
"请输入名称": "Please enter a name",
"请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited",
"无限额度": "Unlimited quota",
"注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.",
"等于": "Equals",
"请输入额度单位token": "Please enter the quota (unit: token)",
"创建令牌数量": "Create key quantity",
"请输入令牌数量": "Please enter the number of keys",
"注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.",
"我的令牌": "My keys",
"请输入额度兑换码!": "Please enter the redeem code!",
"充值成功!": "Recharge successful!",
"请求失败": "Request failed",
"超级管理员未设置充值链接!": "The super administrator did not set a recharge link!",
"充值额度": "Recharge quota",
"兑换中...": "Redeeming...",
"请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.",
"用户信息更新成功!": "User information updated successfully!",
"更新用户信息": "Update user information",
"请输入新的用户名": "Please enter a new username",
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters",
"请输入新的显示名称": "Please enter a new display name",
"分组": "Group",
"请选择分组": "Please select a group",
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:",
"请输入新的剩余额度": "Please enter a new remaining quota",
"已绑定的 GitHub 账户": "Bound GitHub account",
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified",
"已绑定的微信账户": "Bound WeChat account",
"已绑定的邮箱账户": "Bound email account",
"新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5",
"无法正常连接至服务器!": "Unable to connect to the server normally!",
"提示:": "Input:",
"补全:": "Output:",
"搜索令牌名称": "Search key name",
"测试所有渠道": "Test all channels",
"更新已启用渠道余额": "Update the balance of enabled channels"
} }

87
main.go
View File

@@ -6,83 +6,94 @@ import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "github.com/songquanpeng/one-api/common"
"one-api/controller" "github.com/songquanpeng/one-api/common/config"
"one-api/middleware" "github.com/songquanpeng/one-api/common/logger"
"one-api/model" "github.com/songquanpeng/one-api/controller"
"one-api/router" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/router"
"os" "os"
"strconv" "strconv"
) )
//go:embed web/build //go:embed web/build/*
var buildFS embed.FS var buildFS embed.FS
//go:embed web/build/index.html
var indexPage []byte
func main() { func main() {
common.SetupLogger() logger.SetupLogger()
common.SysLog("One API " + common.Version + " started") logger.SysLog(fmt.Sprintf("One API %s started", common.Version))
if os.Getenv("GIN_MODE") != "debug" { if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if common.DebugEnabled { if config.DebugEnabled {
common.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
var err error
// Initialize SQL Database // Initialize SQL Database
err := model.InitDB() model.DB, err = model.InitDB("SQL_DSN")
if err != nil { if err != nil {
common.FatalLog("failed to initialize database: " + err.Error()) logger.FatalLog("failed to initialize database: " + err.Error())
}
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed()
if err != nil {
logger.FatalLog("database init error: " + err.Error())
} }
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {
common.FatalLog("failed to close database: " + err.Error()) logger.FatalLog("failed to close database: " + err.Error())
} }
}() }()
// Initialize Redis // Initialize Redis
err = common.InitRedisClient() err = common.InitRedisClient()
if err != nil { if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error()) logger.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions // for compatibility with old versions
common.MemoryCacheEnabled = true config.MemoryCacheEnabled = true
} }
if common.MemoryCacheEnabled { if config.MemoryCacheEnabled {
common.SysLog("memory cache enabled") logger.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache() model.InitChannelCache()
} }
if common.MemoryCacheEnabled { if config.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency) go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency) go model.SyncChannelCache(config.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
} }
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil { if err != nil {
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true config.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater() model.InitBatchUpdater()
} }
controller.InitTokenEncoders() if config.EnableMetric {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.New()
@@ -92,16 +103,16 @@ func main() {
server.Use(middleware.RequestId()) server.Use(middleware.RequestId())
middleware.SetUpLogger(server) middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage) router.SetRouter(server, buildFS)
var port = os.Getenv("PORT") var port = os.Getenv("PORT")
if port == "" { if port == "" {
port = strconv.Itoa(*common.Port) port = strconv.Itoa(*common.Port)
} }
err = server.Run(":" + port) err = server.Run(":" + port)
if err != nil { if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error()) logger.FatalLog("failed to start HTTP server: " + err.Error())
} }
} }

View File

@@ -3,9 +3,10 @@ package middleware
import ( import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strings" "strings"
) )
@@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if status.(int) == common.UserStatusDisabled { if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": "用户已被封禁",
}) })
session := sessions.Default(c)
session.Clear()
_ = session.Save()
c.Abort() c.Abort()
return return
} }
@@ -99,22 +103,16 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusInternalServerError, err.Error()) abortWithMessage(c, http.StatusInternalServerError, err.Error())
return return
} }
if !userEnabled { if !userEnabled || blacklist.IsUserBanned(token.UserId) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁") abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("specific_channel_id", parts[1])
} else { } else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return return

View File

@@ -2,9 +2,10 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
@@ -20,8 +21,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup) c.Set("group", userGroup)
var requestModel string
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("channelId") channelId, ok := c.Get("specific_channel_id")
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
@@ -40,10 +42,7 @@ func Distribute() func(c *gin.Context) {
} else { } else {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error err := common.UnmarshalBodyReusable(c, &modelRequest)
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return return
@@ -60,39 +59,54 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "dall-e" modelRequest.Model = "dall-e-2"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }
} }
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil { if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员" message = "数据库一致性已被破坏,请联系管理员"
} }
abortWithMessage(c, http.StatusServiceUnavailable, message) abortWithMessage(c, http.StatusServiceUnavailable, message)
return return
} }
} }
SetupContextForSelectedChannel(c, channel, requestModel)
c.Next()
}
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei: case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary: case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other) c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
} }
c.Next() cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
} }
} }

View File

@@ -3,14 +3,14 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "github.com/songquanpeng/one-api/common/logger"
) )
func SetUpLogger(server *gin.Engine) { func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string var requestID string
if param.Keys != nil { if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string) requestID = param.Keys[logger.RequestIdKey].(string)
} }
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"), param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@@ -4,8 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"net/http" "net/http"
"one-api/common"
"time" "time"
) )
@@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
} }
if listLength < int64(maxRequestNum) { if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else { } else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr) oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number! // time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration { if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
c.Abort() c.Abort()
return return
} else { } else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} }
} }
} }
@@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
} else { } else {
// It's safe to call multi times. // It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) { return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark) memoryRateLimiter(c, maxRequestNum, duration, mark)
} }
@@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
func GlobalWebRateLimit() func(c *gin.Context) { func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
} }
func GlobalAPIRateLimit() func(c *gin.Context) { func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
} }
func CriticalRateLimit() func(c *gin.Context) { func CriticalRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
} }
func DownloadRateLimit() func(c *gin.Context) { func DownloadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
} }
func UploadRateLimit() func(c *gin.Context) { func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
} }

33
middleware/recover.go Normal file
View File

@@ -0,0 +1,33 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
ctx := c.Request.Context()
logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
body, _ := common.GetRequestBody(c)
logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -3,16 +3,17 @@ package middleware
import ( import (
"context" "context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
) )
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8) id := helper.GenRequestID()
c.Set(common.RequestIdKey, id) c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id) c.Header(logger.RequestIdKey, id)
c.Next() c.Next()
} }
} }

View File

@@ -4,9 +4,10 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common"
) )
type turnstileCheckResponse struct { type turnstileCheckResponse struct {
@@ -15,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc { func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if common.TurnstileCheckEnabled { if config.TurnstileCheckEnabled {
session := sessions.Default(c) session := sessions.Default(c)
turnstileChecked := session.Get("turnstile") turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil { if turnstileChecked != nil {
@@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc {
return return
} }
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {common.TurnstileSecretKey}, "secret": {config.TurnstileSecretKey},
"response": {response}, "response": {response},
"remoteip": {c.ClientIP()}, "remoteip": {c.ClientIP()},
}) })
if err != nil { if err != nil {
common.SysError(err.Error()) logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res) err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil { if err != nil {
common.SysError(err.Error()) logger.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -2,16 +2,17 @@ package middleware
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
) )
func abortWithMessage(c *gin.Context, statusCode int, message string) { func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error", "type": "one_api_error",
}, },
}) })
c.Abort() c.Abort()
common.LogError(c.Request.Context(), message) logger.Error(c.Request.Context(), message)
} }

View File

@@ -1,7 +1,8 @@
package model package model
import ( import (
"one-api/common" "github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"strings" "strings"
) )
@@ -13,7 +14,7 @@ type Ability struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
ability := Ability{} ability := Ability{}
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"
@@ -23,8 +24,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
} }
var err error = nil var err error = nil
var channelQuery *gorm.DB
if ignoreFirstPriority {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
} else {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
}
if common.UsingSQLite || common.UsingPostgreSQL { if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error err = channelQuery.Order("RANDOM()").First(&ability).Error
} else { } else {

View File

@@ -1,11 +1,14 @@
package model package model
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"math/rand" "math/rand"
"one-api/common"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -14,10 +17,10 @@ import (
) )
var ( var (
TokenCacheSeconds = common.SyncFrequency TokenCacheSeconds = config.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency
) )
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
@@ -42,7 +45,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
} }
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError("Redis set token error: " + err.Error()) logger.SysError("Redis set token error: " + err.Error())
} }
return &token, nil return &token, nil
} }
@@ -62,37 +65,48 @@ func CacheGetUserGroup(id int) (group string, err error) {
} }
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError("Redis set user group error: " + err.Error()) logger.SysError("Redis set user group error: " + err.Error())
} }
} }
return group, err return group, err
} }
func CacheGetUserQuota(id int) (quota int, err error) { func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
quota, err = GetUserQuota(id) quota, err = GetUserQuota(id)
if err != nil { if err != nil {
return 0, err return 0, err
} }
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError("Redis set user quota error: " + err.Error()) logger.Error(ctx, "Redis set user quota error: "+err.Error())
} }
return quota, err return
}
quota, err = strconv.Atoi(quotaString)
return quota, err
} }
func CacheUpdateUserQuota(id int) error { func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
return fetchAndUpdateUserQuota(ctx, id)
}
quota, err = strconv.ParseInt(quotaString, 10, 64)
if err != nil {
return 0, nil
}
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
return fetchAndUpdateUserQuota(ctx, id)
}
return quota, nil
}
func CacheUpdateUserQuota(ctx context.Context, id int) error {
if !common.RedisEnabled { if !common.RedisEnabled {
return nil return nil
} }
quota, err := GetUserQuota(id) quota, err := CacheGetUserQuota(ctx, id)
if err != nil { if err != nil {
return err return err
} }
@@ -100,7 +114,7 @@ func CacheUpdateUserQuota(id int) error {
return err return err
} }
func CacheDecreaseUserQuota(id int, quota int) error { func CacheDecreaseUserQuota(id int, quota int64) error {
if !common.RedisEnabled { if !common.RedisEnabled {
return nil return nil
} }
@@ -127,7 +141,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
} }
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError("Redis set user enabled error: " + err.Error()) logger.SysError("Redis set user enabled error: " + err.Error())
} }
return userEnabled, err return userEnabled, err
} }
@@ -178,20 +192,20 @@ func InitChannelCache() {
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = newGroup2model2channels
channelSyncLock.Unlock() channelSyncLock.Unlock()
common.SysLog("channels synced from database") logger.SysLog("channels synced from database")
} }
func SyncChannelCache(frequency int) { func SyncChannelCache(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Second) time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing channels from database") logger.SysLog("syncing channels from database")
InitChannelCache() InitChannelCache()
} }
} }
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !common.MemoryCacheEnabled { if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model) return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
} }
channelSyncLock.RLock() channelSyncLock.RLock()
defer channelSyncLock.RUnlock() defer channelSyncLock.RUnlock()
@@ -211,5 +225,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
} }
} }
idx := rand.Intn(endIdx) idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
idx = common.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil return channels[idx], nil
} }

View File

@@ -1,14 +1,19 @@
package model package model
import ( import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Channel struct { type Channel struct {
Id int `json:"id"` Id int `json:"id"`
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null;index"` Key string `json:"key" gorm:"type:text"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"` Weight *uint `json:"weight" gorm:"default:0"`
@@ -16,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
@@ -24,25 +29,25 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
var channels []*Channel var channels []*Channel
var err error var err error
if selectAll { switch scope {
case "all":
err = DB.Order("id desc").Find(&channels).Error err = DB.Order("id desc").Find(&channels).Error
} else { case "disabled":
err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
} }
return channels, err return channels, err
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
keyCol := "`key`" err = DB.Omit("key").Where("id = ? or name LIKE ?", helper.String2Int(keyword), keyword+"%").Find(&channels).Error
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }
@@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string {
return *channel.BaseURL return *channel.BaseURL
} }
func (channel *Channel) GetModelMapping() string { func (channel *Channel) GetModelMapping() map[string]string {
if channel.ModelMapping == nil { if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
return "" return nil
} }
return *channel.ModelMapping modelMapping := make(map[string]string)
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
return nil
}
return modelMapping
} }
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
@@ -116,21 +127,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(), TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime), ResponseTime: int(responseTime),
}).Error }).Error
if err != nil { if err != nil {
common.SysError("failed to update response time: " + err.Error()) logger.SysError("failed to update response time: " + err.Error())
} }
} }
func (channel *Channel) UpdateBalance(balance float64) { func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(), BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance, Balance: balance,
}).Error }).Error
if err != nil { if err != nil {
common.SysError("failed to update balance: " + err.Error()) logger.SysError("failed to update balance: " + err.Error())
} }
} }
@@ -144,29 +155,41 @@ func (channel *Channel) Delete() error {
return err return err
} }
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil { if err != nil {
common.SysError("failed to update ability status: " + err.Error()) logger.SysError("failed to update ability status: " + err.Error())
} }
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil { if err != nil {
common.SysError("failed to update channel status: " + err.Error()) logger.SysError("failed to update channel status: " + err.Error())
} }
} }
func UpdateChannelUsedQuota(id int, quota int) { func UpdateChannelUsedQuota(id int, quota int64) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return return
} }
updateChannelUsedQuota(id, quota) updateChannelUsedQuota(id, quota)
} }
func updateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int64) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil { if err != nil {
common.SysError("failed to update channel used quota: " + err.Error()) logger.SysError("failed to update channel used quota: " + err.Error())
} }
} }

View File

@@ -3,14 +3,18 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Log struct { type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"` Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
@@ -31,52 +35,52 @@ const (
) )
func RecordLog(userId int, logType int, content string) { func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: helper.GetTimestamp(),
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) logger.SysError("failed to record log: " + err.Error())
} }
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume, Type: LogTypeConsume,
Content: content, Content: content,
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TokenName: tokenName, TokenName: tokenName,
ModelName: modelName, ModelName: modelName,
Quota: quota, Quota: int(quota),
ChannelId: channelId, ChannelId: channelId,
} }
err := DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) logger.Error(ctx, "failed to record log: "+err.Error())
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB tx = LOG_DB
} else { } else {
tx = DB.Where("type = ?", logType) tx = LOG_DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
@@ -94,7 +98,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err return logs, err
@@ -103,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId) tx = LOG_DB.Where("user_id = ?", userId)
} else { } else {
tx = DB.Where("user_id = ? and type = ?", userId, logType) tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
@@ -124,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)") tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -151,14 +155,14 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
tx.Where("type = ?", LogTypeConsume).Scan(&quota) tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota return quota
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -179,6 +183,43 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func DeleteOldLog(targetTimestamp int64) (int64, error) { func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type LogStatistic struct {
Day string `gorm:"column:day"`
ModelName string `gorm:"column:model_name"`
RequestCount int `gorm:"column:request_count"`
Quota int `gorm:"column:quota"`
PromptTokens int `gorm:"column:prompt_tokens"`
CompletionTokens int `gorm:"column:completion_tokens"`
}
func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) {
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
if common.UsingPostgreSQL {
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
}
if common.UsingSQLite {
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
}
err = LOG_DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,
sum(prompt_tokens) as prompt_tokens,
sum(completion_tokens) as completion_tokens
FROM logs
WHERE type=2
AND user_id= ?
AND created_at BETWEEN ? AND ?
GROUP BY day, model_name
ORDER BY day, model_name
`, userId, start, end).Scan(&LogStatistics).Error
return LogStatistics, err
}

View File

@@ -1,23 +1,29 @@
package model package model
import ( import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
"os" "os"
"strings" "strings"
"time" "time"
) )
var DB *gorm.DB var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error { func CreateRootAccountIfNeed() error {
var user User var user User
//if user.Status != common.UserStatusEnabled { //if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil { if err := DB.First(&user).Error; err != nil {
common.SysLog("no user exists, create a root user for you: username is root, password is 123456") logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456") hashedPassword, err := common.Password2Hash("123456")
if err != nil { if err != nil {
return err return err
@@ -28,20 +34,36 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: helper.GetUUID(),
Quota: 100000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
if config.InitialRootToken != "" {
logger.SysLog("creating initial root token as requested")
token := Token{
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
Status: common.TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: -1,
RemainQuota: 500000000000000,
UnlimitedQuota: true,
}
DB.Create(&token)
}
} }
return nil return nil
} }
func chooseDB() (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" { if os.Getenv(envName) != "" {
dsn := os.Getenv("SQL_DSN") dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
DSN: dsn, DSN: dsn,
@@ -51,80 +73,93 @@ func chooseDB() (*gorm.DB, error) {
}) })
} }
// Use MySQL // Use MySQL
common.SysLog("using MySQL as database") logger.SysLog("using MySQL as database")
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{ return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
// Use SQLite // Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database") logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
func InitDB() (err error) { func InitDB(envName string) (db *gorm.DB, err error) {
db, err := chooseDB() db, err = chooseDB(envName)
if err == nil { if err == nil {
if common.DebugEnabled { if config.DebugSQLEnabled {
db = db.Debug() db = db.Debug()
} }
DB = db sqlDB, err := db.DB()
sqlDB, err := DB.DB()
if err != nil { if err != nil {
return err return nil, err
} }
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !config.IsMasterNode {
return nil return db, err
} }
common.SysLog("database migration started") if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{}) err = db.AutoMigrate(&Channel{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&Token{}) err = db.AutoMigrate(&Token{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&User{}) err = db.AutoMigrate(&User{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&Option{}) err = db.AutoMigrate(&Option{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&Redemption{}) err = db.AutoMigrate(&Redemption{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&Ability{}) err = db.AutoMigrate(&Ability{})
if err != nil { if err != nil {
return err return nil, err
} }
err = db.AutoMigrate(&Log{}) err = db.AutoMigrate(&Log{})
if err != nil { if err != nil {
return err return nil, err
} }
common.SysLog("database migrated") logger.SysLog("database migrated")
err = createRootAccountIfNeed() return db, err
return err
} else { } else {
common.FatalLog(err) logger.FatalLog(err)
} }
return err return db, err
} }
func CloseDB() error { func closeDB(db *gorm.DB) error {
sqlDB, err := DB.DB() sqlDB, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
err = sqlDB.Close() err = sqlDB.Close()
return err return err
} }
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}

View File

@@ -1,7 +1,9 @@
package model package model
import ( import (
"one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -20,67 +22,71 @@ func AllOption() ([]*Option, error) {
} }
func InitOptionMap() { func InitOptionMap() {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string) config.OptionMap = make(map[string]string)
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) config.OptionMap["SMTPServer"] = ""
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) config.OptionMap["SMTPFrom"] = ""
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
common.OptionMap["SMTPServer"] = "" config.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPFrom"] = "" config.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) config.OptionMap["Notice"] = ""
common.OptionMap["SMTPAccount"] = "" config.OptionMap["About"] = ""
common.OptionMap["SMTPToken"] = "" config.OptionMap["HomePageContent"] = ""
common.OptionMap["Notice"] = "" config.OptionMap["Footer"] = config.Footer
common.OptionMap["About"] = "" config.OptionMap["SystemName"] = config.SystemName
common.OptionMap["HomePageContent"] = "" config.OptionMap["Logo"] = config.Logo
common.OptionMap["Footer"] = common.Footer config.OptionMap["ServerAddress"] = ""
common.OptionMap["SystemName"] = common.SystemName config.OptionMap["GitHubClientId"] = ""
common.OptionMap["Logo"] = common.Logo config.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["ServerAddress"] = "" config.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["GitHubClientId"] = "" config.OptionMap["WeChatServerToken"] = ""
common.OptionMap["GitHubClientSecret"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["WeChatServerAddress"] = "" config.OptionMap["MessagePusherAddress"] = ""
common.OptionMap["WeChatServerToken"] = "" config.OptionMap["MessagePusherToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = "" config.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["TurnstileSecretKey"] = "" config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink config.OptionMap["TopUpLink"] = config.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink config.OptionMap["ChatLink"] = config.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
common.OptionMapRWMutex.Unlock() config.OptionMap["Theme"] = config.Theme
config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
func loadOptionsFromDatabase() { func loadOptionsFromDatabase() {
options, _ := AllOption() options, _ := AllOption()
for _, option := range options { for _, option := range options {
if option.Key == "ModelRatio" {
option.Value = common.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
if err != nil { if err != nil {
common.SysError("failed to update option map: " + err.Error()) logger.SysError("failed to update option map: " + err.Error())
} }
} }
} }
@@ -88,7 +94,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) { func SyncOptions(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Second) time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing options from database") logger.SysLog("syncing options from database")
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
} }
@@ -110,113 +116,110 @@ func UpdateOption(key string, value string) error {
} }
func updateOptionMap(key string, value string) (err error) { func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock() defer config.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value config.OptionMap[key] = value
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") { if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true" boolValue := value == "true"
switch key { switch key {
case "PasswordRegisterEnabled": case "PasswordRegisterEnabled":
common.PasswordRegisterEnabled = boolValue config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled": case "PasswordLoginEnabled":
common.PasswordLoginEnabled = boolValue config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled": case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled": case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled": case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled": case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled": case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue config.DisplayTokenStatEnabled = boolValue
} }
} }
switch key { switch key {
case "EmailDomainWhitelist": case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",") config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer": case "SMTPServer":
common.SMTPServer = value config.SMTPServer = value
case "SMTPPort": case "SMTPPort":
intValue, _ := strconv.Atoi(value) intValue, _ := strconv.Atoi(value)
common.SMTPPort = intValue config.SMTPPort = intValue
case "SMTPAccount": case "SMTPAccount":
common.SMTPAccount = value config.SMTPAccount = value
case "SMTPFrom": case "SMTPFrom":
common.SMTPFrom = value config.SMTPFrom = value
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value config.SMTPToken = value
case "ServerAddress": case "ServerAddress":
common.ServerAddress = value config.ServerAddress = value
case "GitHubClientId": case "GitHubClientId":
common.GitHubClientId = value config.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value config.GitHubClientSecret = value
case "Footer": case "Footer":
common.Footer = value config.Footer = value
case "SystemName": case "SystemName":
common.SystemName = value config.SystemName = value
case "Logo": case "Logo":
common.Logo = value config.Logo = value
case "WeChatServerAddress": case "WeChatServerAddress":
common.WeChatServerAddress = value config.WeChatServerAddress = value
case "WeChatServerToken": case "WeChatServerToken":
common.WeChatServerToken = value config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL": case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value config.WeChatAccountQRCodeImageURL = value
case "MessagePusherAddress":
config.MessagePusherAddress = value
case "MessagePusherToken":
config.MessagePusherToken = value
case "TurnstileSiteKey": case "TurnstileSiteKey":
common.TurnstileSiteKey = value config.TurnstileSiteKey = value
case "TurnstileSecretKey": case "TurnstileSecretKey":
common.TurnstileSecretKey = value config.TurnstileSecretKey = value
case "QuotaForNewUser": case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value) config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInviter": case "QuotaForInviter":
common.QuotaForInviter, _ = strconv.Atoi(value) config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInvitee": case "QuotaForInvitee":
common.QuotaForInvitee, _ = strconv.Atoi(value) config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value) config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
case "PreConsumedQuota": case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value) config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
common.TopUpLink = value config.TopUpLink = value
case "ChatLink": case "ChatLink":
common.ChatLink = value config.ChatLink = value
case "ChannelDisableThreshold": case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit": case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme":
config.Theme = value
} }
return err return err
} }

View File

@@ -3,8 +3,9 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Redemption struct { type Redemption struct {
@@ -13,7 +14,7 @@ type Redemption struct {
Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"` Quota int64 `json:"quota" gorm:"bigint;default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request Count int `json:"count" gorm:"-:all"` // only for api request
@@ -41,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err return &redemption, err
} }
func Redeem(key string, userId int) (quota int, err error) { func Redeem(key string, userId int) (quota int64, err error) {
if key == "" { if key == "" {
return 0, errors.New("未提供兑换码") return 0, errors.New("未提供兑换码")
} }
@@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return err return err
} }
redemption.RedeemedTime = common.GetTimestamp() redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err

View File

@@ -3,8 +3,12 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Token struct { type Token struct {
@@ -16,15 +20,26 @@ type Token struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"` RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
} }
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token var tokens []*Token
var err error var err error
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
case "used_quota":
query = query.Order("used_quota desc")
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err return tokens, err
} }
@@ -38,7 +53,13 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("未提供令牌") return nil, errors.New("未提供令牌")
} }
token, err = CacheGetTokenByKey(key) token, err = CacheGetTokenByKey(key)
if err == nil { if err != nil {
logger.SysError("CacheGetTokenByKey failed: " + err.Error())
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("无效的令牌")
}
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted { if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽") return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired { } else if token.Status == common.TokenStatusExpired {
@@ -47,12 +68,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())
} }
} }
return nil, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
@@ -63,15 +84,13 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) logger.SysError("failed to update token status" + err.Error())
} }
} }
return nil, errors.New("该令牌额度已用尽") return nil, errors.New("该令牌额度已用尽")
} }
return token, nil return token, nil
} }
return nil, errors.New("无效的令牌")
}
func GetTokenByIds(id int, userId int) (*Token, error) { func GetTokenByIds(id int, userId int) (*Token, error) {
if id == 0 || userId == 0 { if id == 0 || userId == 0 {
@@ -130,51 +149,51 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete() return token.Delete()
} }
func IncreaseTokenQuota(id int, quota int) (err error) { func IncreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota) addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil return nil
} }
return increaseTokenQuota(id, quota) return increaseTokenQuota(id, quota)
} }
func increaseTokenQuota(id int, quota int) (err error) { func increaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates( err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota), "remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
} }
func DecreaseTokenQuota(id int, quota int) (err error) { func DecreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil return nil
} }
return decreaseTokenQuota(id, quota) return decreaseTokenQuota(id, quota)
} }
func decreaseTokenQuota(id int, quota int) (err error) { func decreaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates( err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota), "remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
} }
func PreConsumeTokenQuota(tokenId int, quota int) (err error) { func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
@@ -192,24 +211,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota { if userQuota < quota {
return errors.New("用户额度不足") return errors.New("用户额度不足")
} }
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0 noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(token.UserId) email, err := GetUserEmail(token.UserId)
if err != nil { if err != nil {
common.SysError("failed to fetch user email: " + err.Error()) logger.SysError("failed to fetch user email: " + err.Error())
} }
prompt := "您的额度即将用尽" prompt := "您的额度即将用尽"
if noMoreQuota { if noMoreQuota {
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = common.SendEmail(prompt, email, err = message.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {
common.SysError("failed to send email" + err.Error()) logger.SysError("failed to send email" + err.Error())
} }
} }
}() }()
@@ -224,7 +243,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
return err return err
} }
func PostConsumeTokenQuota(tokenId int, quota int) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId) token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)

View File

@@ -3,8 +3,12 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
"strings" "strings"
) )
@@ -15,15 +19,15 @@ type User struct {
Username string `json:"username" gorm:"unique;index" validate:"max=12"` Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common Role int `json:"role" gorm:"type:int;default:1"` // admin, util
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int64 `json:"quota" gorm:"bigint;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
@@ -36,13 +40,30 @@ func GetMaxUserId() int {
return user.Id return user.Id
} }
func GetAllUsers(startIdx int, num int) (users []*User, err error) { func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err return users, err
} }
func SearchUsers(keyword string) (users []*User, err error) { func SearchUsers(keyword string) (users []*User, err error) {
if !common.UsingPostgreSQL {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
} else {
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
}
return users, err return users, err
} }
@@ -85,24 +106,24 @@ func (user *User) Insert(inviterId int) error {
return err return err
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = config.QuotaForNewUser
user.AccessToken = common.GetUUID() user.AccessToken = helper.GetUUID()
user.AffCode = common.GetRandomString(4) user.AffCode = helper.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if common.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
return nil return nil
@@ -116,6 +137,11 @@ func (user *User) Update(updatePassword bool) error {
return err return err
} }
} }
if user.Status == common.UserStatusDisabled {
blacklist.BanUser(user.Id)
} else if user.Status == common.UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error err = DB.Model(user).Updates(user).Error
return err return err
} }
@@ -124,7 +150,10 @@ func (user *User) Delete() error {
if user.Id == 0 { if user.Id == 0 {
return errors.New("id 为空!") return errors.New("id 为空!")
} }
err := DB.Delete(user).Error blacklist.BanUser(user.Id)
user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
user.Status = common.UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err return err
} }
@@ -137,7 +166,15 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" { if user.Username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
DB.Where(User{Username: user.Username}).First(user) err = DB.Where("username = ?", user.Username).First(user).Error
if err != nil {
// we must make sure check username firstly
// consider this case: a malicious user set his username as other's email
err := DB.Where("email = ?", user.Username).First(user).Error
if err != nil {
return errors.New("用户名或密码错误,或用户已被封禁")
}
}
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@@ -220,7 +257,7 @@ func IsAdmin(userId int) bool {
var user User var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil { if err != nil {
common.SysError("no such user " + err.Error()) logger.SysError("no such user " + err.Error())
return false return false
} }
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
@@ -250,12 +287,12 @@ func ValidateAccessToken(token string) (user *User) {
return nil return nil
} }
func GetUserQuota(id int) (quota int, err error) { func GetUserQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
return quota, err return quota, err
} }
func GetUserUsedQuota(id int) (quota int, err error) { func GetUserUsedQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
return quota, err return quota, err
} }
@@ -275,34 +312,34 @@ func GetUserGroup(id int) (group string, err error) {
return group, err return group, err
} }
func IncreaseUserQuota(id int, quota int) (err error) { func IncreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
} }
return increaseUserQuota(id, quota) return increaseUserQuota(id, quota)
} }
func increaseUserQuota(id int, quota int) (err error) { func increaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err return err
} }
func DecreaseUserQuota(id int, quota int) (err error) { func DecreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota) addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil return nil
} }
return decreaseUserQuota(id, quota) return decreaseUserQuota(id, quota)
} }
func decreaseUserQuota(id int, quota int) (err error) { func decreaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err return err
} }
@@ -312,8 +349,8 @@ func GetRootUserEmail() (email string) {
return email return email
} }
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1) addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return return
@@ -321,7 +358,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
updateUserUsedQuotaAndRequestCount(id, quota, 1) updateUserUsedQuotaAndRequestCount(id, quota, 1)
} }
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates( err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
@@ -329,25 +366,25 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}, },
).Error ).Error
if err != nil { if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error()) logger.SysError("failed to update user used quota and request count: " + err.Error())
} }
} }
func updateUserUsedQuota(id int, quota int) { func updateUserUsedQuota(id int, quota int64) {
err := DB.Model(&User{}).Where("id = ?", id).Updates( err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
}, },
).Error ).Error
if err != nil { if err != nil {
common.SysError("failed to update user used quota: " + err.Error()) logger.SysError("failed to update user used quota: " + err.Error())
} }
} }
func updateUserRequestCount(id int, count int) { func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil { if err != nil {
common.SysError("failed to update user request count: " + err.Error()) logger.SysError("failed to update user request count: " + err.Error())
} }
} }

View File

@@ -1,7 +1,8 @@
package model package model
import ( import (
"one-api/common" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"sync" "sync"
"time" "time"
) )
@@ -15,12 +16,12 @@ const (
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
) )
var batchUpdateStores []map[int]int var batchUpdateStores []map[int]int64
var batchUpdateLocks []sync.Mutex var batchUpdateLocks []sync.Mutex
func init() { func init() {
for i := 0; i < BatchUpdateTypeCount; i++ { for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int)) batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
} }
} }
@@ -28,13 +29,13 @@ func init() {
func InitBatchUpdater() { func InitBatchUpdater() {
go func() { go func() {
for { for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate() batchUpdate()
} }
}() }()
} }
func addNewRecord(type_ int, id int, value int) { func addNewRecord(type_ int, id int, value int64) {
batchUpdateLocks[type_].Lock() batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock() defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok { if _, ok := batchUpdateStores[type_][id]; !ok {
@@ -45,11 +46,11 @@ func addNewRecord(type_ int, id int, value int) {
} }
func batchUpdate() { func batchUpdate() {
common.SysLog("batch update started") logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ { for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock() batchUpdateLocks[i].Lock()
store := batchUpdateStores[i] store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int) batchUpdateStores[i] = make(map[int]int64)
batchUpdateLocks[i].Unlock() batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key? // TODO: maybe we can combine updates with same key?
for key, value := range store { for key, value := range store {
@@ -57,21 +58,21 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota: case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value) err := increaseUserQuota(key, value)
if err != nil { if err != nil {
common.SysError("failed to batch update user quota: " + err.Error()) logger.SysError("failed to batch update user quota: " + err.Error())
} }
case BatchUpdateTypeTokenQuota: case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value) err := increaseTokenQuota(key, value)
if err != nil { if err != nil {
common.SysError("failed to batch update token quota: " + err.Error()) logger.SysError("failed to batch update token quota: " + err.Error())
} }
case BatchUpdateTypeUsedQuota: case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value) updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount: case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value) updateUserRequestCount(key, int(value))
case BatchUpdateTypeChannelUsedQuota: case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value) updateChannelUsedQuota(key, value)
} }
} }
} }
common.SysLog("batch update finished") logger.SysLog("batch update finished")
} }

55
monitor/channel.go Normal file
View File

@@ -0,0 +1,55 @@
package monitor
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/model"
)
func notifyRootUser(subject string, content string) {
if config.MessagePusherAddress != "" {
err := message.SendMessage(subject, content, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send message: %s", err.Error()))
} else {
return
}
}
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := message.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

79
monitor/metric.go Normal file
View File

@@ -0,0 +1,79 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
)
var store = make(map[int][]bool)
var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
var metricFailChan = make(chan int, config.MetricFailChanSize)
func consumeSuccess(channelId int) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], true)
}
func consumeFail(channelId int) (bool, float64) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], false)
successCount := 0
for _, success := range store[channelId] {
if success {
successCount++
}
}
successRate := float64(successCount) / float64(len(store[channelId]))
if len(store[channelId]) < config.MetricQueueSize {
return false, successRate
}
if successRate < config.MetricSuccessRateThreshold {
store[channelId] = make([]bool, 0)
return true, successRate
}
return false, successRate
}
func metricSuccessConsumer() {
for {
select {
case channelId := <-metricSuccessChan:
consumeSuccess(channelId)
}
}
}
func metricFailConsumer() {
for {
select {
case channelId := <-metricFailChan:
disable, successRate := consumeFail(channelId)
if disable {
go MetricDisableChannel(channelId, successRate)
}
}
}
}
func init() {
if config.EnableMetric {
go metricSuccessConsumer()
go metricFailConsumer()
}
}
func Emit(channelId int, success bool) {
if !config.EnableMetric {
return
}
go func() {
if success {
metricSuccessChan <- channelId
} else {
metricFailChan <- channelId
}
}()
}

10
pull_request_template.md Normal file
View File

@@ -0,0 +1,10 @@
[//]: # (请按照以下格式关联 issue)
[//]: # (请在提交 PR 前确认所提交的功能可用,需要附上截图,谢谢)
[//]: # (项目维护者一般仅在周末处理 PR因此如若未能及时回复希望能理解)
[//]: # (开发者交流群910657413)
[//]: # (请在提交 PR 之前删除上面的注释)
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:
(此处放上测试通过的截图,如果不涉及前端改动或从 UI 上无法看出,请放终端启动成功的截图)

View File

@@ -0,0 +1,8 @@
package ai360
var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}

View File

@@ -0,0 +1,60 @@
package aiproxy
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aiproxy"
}

View File

@@ -0,0 +1,9 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
var ModelList = []string{""}
func init() {
ModelList = openai.ModelList
}

View File

@@ -1,63 +1,37 @@
package controller package aiproxy
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"
"one-api/common"
"strconv" "strconv"
"strings" "strings"
) )
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答 // https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct { func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := "" query := ""
if len(request.Messages) != 0 { if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content query = request.Messages[len(request.Messages)-1].StringContent()
} }
return &AIProxyLibraryRequest{ return &LibraryRequest{
Model: request.Model, Model: request.Model,
Stream: request.Stream, Stream: request.Stream,
Query: query, Query: query,
} }
} }
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 { if len(documents) == 0 {
return "" return ""
} }
@@ -68,52 +42,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
return content return content
} }
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents) content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := OpenAITextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
Message: Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: content, Content: content,
}, },
FinishReason: "stop", FinishReason: "stop",
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := openai.TextResponse{
Id: common.GetUUID(), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
} }
return &fullTextResponse return &fullTextResponse
} }
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse { func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason choice.FinishReason = &constant.StopFinishReason
return &ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: common.GetUUID(), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "", Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }
} }
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse { func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: common.GetUUID(), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: response.Model, Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }
} }
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage Usage var usage model.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@@ -143,15 +117,15 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) common.SetEventStreamHeaders(c)
var documents []AIProxyLibraryDocument var documents []LibraryDocument
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse) err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
if len(AIProxyLibraryResponse.Documents) != 0 { if len(AIProxyLibraryResponse.Documents) != 0 {
@@ -160,7 +134,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -169,7 +143,7 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
response := documentsAIProxyLibrary(documents) response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -179,28 +153,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse) err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if AIProxyLibraryResponse.ErrCode != 0 { if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{ return &model.ErrorWithStatusCode{
OpenAIError: OpenAIError{ Error: model.Error{
Message: AIProxyLibraryResponse.Message, Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode), Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode, Code: AIProxyLibraryResponse.ErrCode,
@@ -211,10 +185,13 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }

View File

@@ -0,0 +1,32 @@
package aiproxy
type LibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type LibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type LibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type LibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []LibraryDocument `json:"documents"`
LibraryError
}
type LibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []LibraryDocument `json:"documents"`
}

View File

@@ -0,0 +1,86 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream {
req.Header.Set("Accept", "text/event-stream")
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -0,0 +1,6 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}

267
relay/channel/ali/main.go Normal file
View File

@@ -0,0 +1,267 @@
package ali
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
if request.TopP >= 1 {
request.TopP = 0.9999
}
return &ChatRequest{
Model: aliModel,
Input: Input{
Messages: messages,
},
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
TopK: request.TopK,
ResultFormat: "message",
Tools: request.Tools,
},
}
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(aliResponse.Output.Choices) == 0 {
return nil
}
aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta = aliChoice.Message
if aliChoice.FinishReason != "null" {
finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,81 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -0,0 +1,63 @@
package anthropic
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
req.Header.Set("anthropic-beta", "messages-2023-12-15")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "anthropic"
}

View File

@@ -0,0 +1,8 @@
package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}

View File

@@ -0,0 +1,273 @@
package anthropic
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func stopReasonClaude2OpenAI(reason *string) string {
if reason == nil {
return ""
}
switch *reason {
case "end_turn":
return "stop"
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return *reason
}
}
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
// legacy model name mapping
if claudeRequest.Model == "claude-instant-1" {
claudeRequest.Model = "claude-instant-1.1"
} else if claudeRequest.Model == "claude-2" {
claudeRequest.Model = "claude-2.1"
}
for _, message := range textRequest.Messages {
if message.Role == "system" && claudeRequest.System == "" {
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := Message{
Role: message.Role,
}
var content Content
if message.IsStringContent() {
content.Type = "text"
content.Text = message.StringContent()
claudeMessage.Content = append(claudeMessage.Content, content)
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
continue
}
var contents []Content
openaiContent := message.ParseContent()
for _, part := range openaiContent {
var content Content
if part.Type == model.ContentTypeText {
content.Type = "text"
content.Text = part.Text
} else if part.Type == model.ContentTypeImageURL {
content.Type = "image"
content.Source = &ImageSource{
Type: "base64",
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
content.Source.MediaType = mimeType
content.Source.Data = data
}
contents = append(contents, content)
}
claudeMessage.Content = contents
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
return &claudeRequest
}
// https://docs.anthropic.com/claude/reference/messages-streaming
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
var response *Response
var responseText string
var stopReason string
switch claudeResponse.Type {
case "message_start":
return nil, claudeResponse.Message
case "content_block_start":
if claudeResponse.ContentBlock != nil {
responseText = claudeResponse.ContentBlock.Text
}
case "content_block_delta":
if claudeResponse.Delta != nil {
responseText = claudeResponse.Delta.Text
}
case "message_delta":
if claudeResponse.Usage != nil {
response = &Response{
Usage: *claudeResponse.Usage,
}
}
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
stopReason = *claudeResponse.Delta.StopReason
}
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = responseText
choice.Delta.Role = "assistant"
finishReason := stopReasonClaude2OpenAI(&stopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var openaiResponse openai.ChatCompletionsStreamResponse
openaiResponse.Object = "chat.completion.chunk"
openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &openaiResponse, response
}
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: responseText,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
Model: claudeResponse.Model,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 {
continue
}
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var usage model.Usage
var modelName string
var id string
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse StreamResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
modelName = meta.Model
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
}
if response == nil {
return true
}
response.Id = id
response.Model = modelName
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
_ = resp.Body.Close()
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName
usage := model.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,75 @@
package anthropic
// https://docs.anthropic.com/claude/reference/messages_post
type Metadata struct {
UserId string `json:"user_id"`
}
type ImageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type Content struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ImageSource `json:"source,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content []Content `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Id string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Content []Content `json:"content"`
Model string `json:"model"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
Usage Usage `json:"usage"`
Error Error `json:"error"`
}
type Delta struct {
Type string `json:"type"`
Text string `json:"text"`
StopReason *string `json:"stop_reason"`
StopSequence *string `json:"stop_sequence"`
}
type StreamResponse struct {
Type string `json:"type"`
Message *Response `json:"message"`
Index int `json:"index"`
ContentBlock *Content `json:"content_block"`
Delta *Delta `json:"delta"`
Usage *Usage `json:"usage"`
}

View File

@@ -0,0 +1,7 @@
package baichuan
var ModelList = []string{
"Baichuan2-Turbo",
"Baichuan2-Turbo-192k",
"Baichuan-Text-Embedding",
}

View File

@@ -0,0 +1,118 @@
package baidu
import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(meta.ActualModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(meta.ActualModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(meta.ActualModelName, "tao-8k") {
suffix = "embeddings/"
}
switch meta.ActualModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += meta.ActualModelName
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "baidu"
}

View File

@@ -0,0 +1,13 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}

View File

@@ -1,4 +1,4 @@
package controller package baidu
import ( import (
"bufio" "bufio"
@@ -6,9 +6,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"one-api/common"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -16,148 +21,111 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct { type TokenResponse struct {
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
} }
type BaiduMessage struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
} }
type BaiduChatRequest struct { type ChatRequest struct {
Messages []BaiduMessage `json:"messages"` Messages []Message `json:"messages"`
Stream bool `json:"stream"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"` UserId string `json:"user_id,omitempty"`
} }
type BaiduError struct { type Error struct {
ErrorCode int `json:"error_code"` ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"` ErrorMsg string `json:"error_msg"`
} }
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages)) baiduRequest := ChatRequest{
Messages: make([]Message, 0, len(request.Messages)),
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: request.MaxTokens,
UserId: request.User,
}
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, BaiduMessage{ baiduRequest.System = message.StringContent()
Role: "user",
Content: message.Content,
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else { } else {
messages = append(messages, BaiduMessage{ baiduRequest.Messages = append(baiduRequest.Messages, Message{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }
return &BaiduChatRequest{ return &baiduRequest
Messages: messages,
Stream: request.Stream,
}
} }
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := OpenAITextResponseChoice{ choice := openai.TextResponseChoice{
Index: 0, Index: 0,
Message: Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: response.Result, Content: response.Result,
}, },
FinishReason: "stop", FinishReason: "stop",
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := openai.TextResponse{
Id: response.Id, Id: response.Id,
Object: "chat.completion", Object: "chat.completion",
Created: response.Created, Created: response.Created,
Choices: []OpenAITextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
Usage: response.Usage, Usage: response.Usage,
} }
return &fullTextResponse return &fullTextResponse
} }
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd { if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason choice.FinishReason = &constant.StopFinishReason
} }
response := ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: baiduResponse.Id, Id: baiduResponse.Id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: baiduResponse.Created, Created: baiduResponse.Created,
Model: "ernie-bot", Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
} }
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &BaiduEmbeddingRequest{ return &EmbeddingRequest{
Input: request.ParseInput(), Input: request.ParseInput(),
} }
} }
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list", Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding", Model: "baidu-embedding",
Usage: response.Usage, Usage: response.Usage,
} }
for _, item := range response.Data { for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object, Object: item.Object,
Index: item.Index, Index: item.Index,
Embedding: item.Embedding, Embedding: item.Embedding,
@@ -166,8 +134,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage Usage var usage model.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@@ -194,14 +162,14 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse) err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
if baiduResponse.Usage.TotalTokens != 0 { if baiduResponse.Usage.TotalTokens != 0 {
@@ -212,7 +180,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
response := streamResponseBaidu2OpenAI(&baiduResponse) response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -224,28 +192,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, &usage return nil, &usage
} }
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse BaiduChatResponse var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if baiduResponse.ErrorMsg != "" { if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{ return &model.ErrorWithStatusCode{
OpenAIError: OpenAIError{ Error: model.Error{
Message: baiduResponse.ErrorMsg, Message: baiduResponse.ErrorMsg,
Type: "baidu_error", Type: "baidu_error",
Param: "", Param: "",
@@ -255,9 +223,10 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
}, nil }, nil
} }
fullTextResponse := responseBaidu2OpenAI(&baiduResponse) fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
@@ -265,23 +234,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse BaiduEmbeddingResponse var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if baiduResponse.ErrorMsg != "" { if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{ return &model.ErrorWithStatusCode{
OpenAIError: OpenAIError{ Error: model.Error{
Message: baiduResponse.ErrorMsg, Message: baiduResponse.ErrorMsg,
Type: "baidu_error", Type: "baidu_error",
Param: "", Param: "",
@@ -293,7 +262,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
@@ -301,10 +270,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
func getBaiduAccessToken(apiKey string) (string, error) { func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok { if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken var accessToken AccessToken
if accessToken, ok = val.(BaiduAccessToken); ok { if accessToken, ok = val.(AccessToken); ok {
// soon this will expire // soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() { go func() {
@@ -319,12 +288,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
return "", err return "", err
} }
if accessToken == nil { if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token") return "", errors.New("GetAccessToken return a nil token")
} }
return (*accessToken).AccessToken, nil return (*accessToken).AccessToken, nil
} }
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|") parts := strings.Split(apiKey, "|")
if len(parts) != 2 { if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey") return nil, errors.New("invalid baidu apikey")
@@ -336,13 +305,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req) res, err := util.ImpatientHTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
var accessToken BaiduAccessToken var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken) err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil { if err != nil {
return nil, err return nil, err

Some files were not shown because too many files have changed in this diff Show More