Compare commits

...

189 Commits

Author SHA1 Message Date
wozulong
c47e1dc6fe merge upstream
Signed-off-by: wozulong <>
2024-10-14 16:31:22 +08:00
1808837298@qq.com
40baa636e4 fix: 修复自定义聊天bug
(cherry picked from commit 8d41c17ccf19cb29100dbe506d3d42a6be822ff9)
2024-10-13 00:21:52 +08:00
1808837298@qq.com
d6359ec4ff feat: 完善自定义聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
89ddf83b44 feat: 弃用旧的聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
6a8a4bcf65 fix: playground group 2024-10-10 13:39:09 +08:00
1808837298@qq.com
e298f2e5a4 feat: playground token name 2024-10-10 13:34:29 +08:00
1808837298@qq.com
8cea6dff4a feat: support embedding encoding_format param 2024-10-10 13:23:12 +08:00
1808837298@qq.com
5035cd054a feat: update aws claude 2024-10-09 00:42:36 +08:00
1808837298@qq.com
02c0c6501e feat: update auto disable 2024-10-08 23:15:57 +08:00
1808837298@qq.com
f0b808a41d feat: update model ratio 2024-10-03 21:12:09 +08:00
1808837298@qq.com
31d84ee32f feat: update model ratio 2024-10-03 20:48:47 +08:00
1808837298@qq.com
9969ed2d7c feat: update model ratio 2024-10-03 20:47:54 +08:00
1808837298@qq.com
746311242b fix: playground气泡溢出 #511 2024-09-27 20:49:26 +08:00
1808837298@qq.com
04a68a85dd feat: 优化playground样式 2024-09-27 20:49:25 +08:00
1808837298@qq.com
f9ba10f180 fix: playground max_tokens #512 #511 2024-09-27 20:18:53 +08:00
Calcium-Ion
334a6f8280 Update README.md 2024-09-26 01:54:33 +08:00
1808837298@qq.com
0cf53ac5ff feat: Playground相关接口禁用AccessToken 2024-09-26 01:49:35 +08:00
Calcium-Ion
af02cdc58b Merge pull request #509 from Calcium-Ion/playground
feat: playground
2024-09-26 01:00:33 +08:00
1808837298@qq.com
9a4ca1e210 feat: playground 2024-09-26 00:59:09 +08:00
wozulong
27b8495698 upgrade stripe
Signed-off-by: wozulong <>
2024-09-25 17:49:07 +08:00
1808837298@qq.com
9fe1f35fd1 fix: 第三方登录注销 #500 2024-09-25 17:15:59 +08:00
1808837298@qq.com
972ac1ee0f fix: 第三方登录注销 #500 2024-09-25 17:13:28 +08:00
1808837298@qq.com
0f95502b04 feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
b58b1dc0ec feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
05d9aa61df feat: 不自动生成系统访问令牌 2024-09-25 16:31:25 +08:00
1808837298@qq.com
221894d972 fix: error user role 2024-09-24 17:49:57 +08:00
wozulong
333849429b merge upstream
Signed-off-by: wozulong <>
2024-09-23 11:15:45 +08:00
1808837298@qq.com
50eab6b4e4 chore: 更新令牌分组描述 2024-09-22 19:43:06 +08:00
1808837298@qq.com
ed972eef06 feat: pricing page support multi groups #487 2024-09-22 17:44:57 +08:00
CalciumIon
c6ff785a83 feat: 无可选分组时关闭令牌分组功能 #485 2024-09-19 03:01:33 +08:00
CalciumIon
2e734e0c37 chore: 令牌分组描述歧义 2024-09-19 02:52:25 +08:00
CalciumIon
af33f36c7b feat: update gemini flash completion ratio #479 2024-09-18 20:39:06 +08:00
CalciumIon
3aa86a8cd9 feat: update gemini completion ratio #479 2024-09-18 20:37:22 +08:00
CalciumIon
af7fecbfa7 fix: 使用令牌分组时 "/v1/models" 返回模型不正确 #481 2024-09-18 19:19:37 +08:00
CalciumIon
3fbdd502b6 fix: token group #477 2024-09-18 18:55:11 +08:00
CalciumIon
052bc2075b feat: 令牌分组 2024-09-18 05:19:49 +08:00
Calcium-Ion
5f3798053f Create FUNDING.yml 2024-09-18 01:41:31 +08:00
CalciumIon
e31022c676 Update logo 2024-09-18 01:25:00 +08:00
Calcium-Ion
fff7609f06 Merge pull request #439 from guoruqiang/main
改进了聊天页面,增加了初始令牌,方便用户注册后即可使用聊天功能。
2024-09-17 23:14:19 +08:00
CalciumIon
9032b5cfbf fix: 初始令牌 2024-09-17 23:07:16 +08:00
CalciumIon
131453dac8 Update README.md 2024-09-17 23:01:34 +08:00
CalciumIon
ed948c121a Merge branch 'main' into g-main
# Conflicts:
#	web/src/App.js
2024-09-17 22:50:59 +08:00
CalciumIon
a03cd15505 fix: '/v1/models' #474 2024-09-17 22:41:54 +08:00
CalciumIon
02f5137781 fix: '/v1/models' #474 2024-09-17 22:39:58 +08:00
CalciumIon
e6df0ed20c fix: '/vi/models' #474 2024-09-17 22:36:20 +08:00
CalciumIon
f505afdc10 feat: 添加令牌ip白名单功能 2024-09-17 20:49:51 +08:00
CalciumIon
feb1d76942 feat: 优化界面显示 2024-09-17 19:55:18 +08:00
CalciumIon
6263616cd9 Update README.md 2024-09-17 03:18:12 +08:00
wozulong
9b14c5da63 merge upstream
Signed-off-by: wozulong <>
2024-09-15 17:30:59 +08:00
GuoRuqiang
6bbf1d4843 Merge branch 'Calcium-Ion:main' into main 2024-09-14 19:00:03 +08:00
1808837298@qq.com
13c993d87e feat: format o1 model max tokens param 2024-09-14 16:11:38 +08:00
CalciumIon
cb73889353 feat: support o1 channel test 2024-09-13 03:17:04 +08:00
CalciumIon
804aad3f37 feat: support o1 channel test 2024-09-13 03:15:32 +08:00
CalciumIon
3af62a3efa feat: support OpenAI o1-preview and o1-mini 2024-09-13 01:22:27 +08:00
CalciumIon
be54369c12 chore: update footer 2024-09-12 18:43:01 +08:00
CalciumIon
0cbf8e07e7 feat: support ollama multi-text embedding 2024-09-12 18:29:45 +08:00
Calcium-Ion
1675679be9 Merge pull request #464 from Yan-Zero/main
fix: tool use in claude and add gemini mapping
2024-09-12 05:04:19 +08:00
Yan
0b5f2a7089 add gemini exp 2024-09-11 19:37:03 +08:00
Yan Tau
b5bb708072 Merge branch 'Calcium-Ion:main' into main 2024-09-11 19:29:50 +08:00
CalciumIon
2650ec9b59 feat: claude response return model name 2024-09-11 19:12:55 +08:00
CalciumIon
d168a685c1 fix: cohere SafetyMode 2024-09-11 19:12:32 +08:00
wozulong
a9e3555cac update
Signed-off-by: wozulong <>
2024-09-10 17:38:30 +08:00
wozulong
8d83630d85 merge upstream
Signed-off-by: wozulong <>
2024-09-10 17:28:04 +08:00
wozulong
a60f209c85 optimized blocking issue during bulk log data deletion
Signed-off-by: wozulong <>
2024-09-10 17:24:44 +08:00
GuoRuqiang
a0d20896b3 Merge branch 'Calcium-Ion:main' into main 2024-09-08 15:56:54 +08:00
Calcium-Ion
5cab06d1ce Merge pull request #459 from HynoR/main
chore: 适配cohere的safety参数
2024-09-05 18:37:47 +08:00
CalciumIon
e3b3fdec48 feat: update chatgpt-4o token encoder 2024-09-05 18:35:34 +08:00
CalciumIon
5863aa8061 feat: remove lobe chat link #457 2024-09-05 18:34:04 +08:00
wozulong
208bc5e794 merge upstream
Signed-off-by: wozulong <>
2024-09-05 17:36:17 +08:00
Yan
0ada2371b6 fix: tool use in claude 2024-09-05 00:53:00 +08:00
CalciumIon
8bc1e956cf fix: email 2024-09-04 19:44:29 +08:00
GuoRuqiang
a0673ef2b6 Merge branch 'Calcium-Ion:main' into main 2024-09-02 21:53:54 +08:00
HynoR
416f831a6c Merge remote-tracking branch 'origin/main' 2024-09-02 06:47:58 +07:00
HynoR
0b4317ce28 Update Cohere Safety Setting 2024-09-02 06:47:49 +07:00
Calcium-Ion
12e2481acb Merge pull request #451 from Nana7mi1/main
feat: support more zhipu models
2024-09-02 01:12:10 +08:00
Calcium-Ion
270709064d Merge pull request #455 from HynoR/feat/cohere-update
Feat: 更新Cohere新模型和定价
2024-09-02 01:11:55 +08:00
CalciumIon
0830ef3305 feat: support jina embedding 2024-09-02 01:11:19 +08:00
HynoR
722cc174b7 Cohere Update 2024-09-01 15:21:05 +07:00
Nanami
97c18d0c7f feat: support more zhipu models 2024-08-31 10:20:22 +08:00
GuoRuqiang
2223aeb022 Merge branch 'Calcium-Ion:main' into main 2024-08-29 19:42:03 +08:00
CalciumIon
4b1e83c42d feat: support siliconflow embedding #447 2024-08-29 00:19:30 +08:00
GuoRuqiang
ecf2f7f212 Merge branch 'Calcium-Ion:main' into main 2024-08-28 21:44:54 +08:00
CalciumIon
01fd8b53a6 feat: 检测vertex渠道部署地区是否填写 2024-08-28 18:47:27 +08:00
CalciumIon
e60f200192 feat: 支持vertex ai渠道多个部署地区 2024-08-28 18:43:40 +08:00
GuoRuqiang
033359e93c Merge branch 'Calcium-Ion:main' into main 2024-08-28 10:44:14 +08:00
CalciumIon
c41820541d Update go.mod 2024-08-27 20:30:46 +08:00
CalciumIon
228f0c5ee5 Update README.md 2024-08-27 20:25:55 +08:00
Calcium-Ion
8a5e074f14 Merge pull request #448 from Calcium-Ion/vertex
feat: support vertex ai
2024-08-27 20:21:01 +08:00
CalciumIon
ac4262c542 feat: support vertex ai #377 2024-08-27 20:19:51 +08:00
wozulong
400b2b0ed0 merge upstream
Signed-off-by: wozulong <>
2024-08-26 18:50:07 +08:00
wozulong
27b034674d add encoding_format
Signed-off-by: wozulong <>
2024-08-26 18:39:13 +08:00
GuoRuqiang
1379d7f184 Merge pull request #2 from j471782517/main
增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。
2024-08-25 02:53:47 +08:00
Jin Weihan
716bf6f48a 增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。 2024-08-24 18:44:37 +00:00
GuoRuqiang
2422eb2820 Merge branch 'Calcium-Ion:main' into main 2024-08-25 01:55:23 +08:00
CalciumIon
46e03683ce fix: channel auto ban 2024-08-24 17:27:14 +08:00
CalciumIon
ff0985f06e fix: channel auto ban #443 2024-08-24 17:23:24 +08:00
CalciumIon
a8ac8a25d5 feat: format claude messages when first role is not user 2024-08-24 17:15:55 +08:00
Xyfacai
5b2082ba58 Merge branch 'main' of https://github.com/Calcium-Ion/new-api 2024-08-24 13:36:44 +08:00
Xyfacai
967ccabb56 fix: 修复 dall-e-2 请求报错 2024-08-24 13:36:41 +08:00
CalciumIon
144513f1d8 feat: rerank model mapping (close #444) 2024-08-23 23:21:37 +08:00
Calcium-Ion
e3087e9bea Merge pull request #445 from OswinWu/fix-outlook-ofb
fix: 多地区outlook邮箱和ofb邮箱Auth
2024-08-23 23:16:37 +08:00
OswinWu
484a8595e4 fix: 多地区outlook邮箱和ofb邮箱Auth 2024-08-23 17:16:09 +08:00
GuoRuqiang
c97e2875b4 增加注册自动生成初始令牌。 2024-08-18 15:12:59 +00:00
GuoRuqiang
64794630c8 修改提示时间。 2024-08-17 16:59:31 +00:00
GuoRuqiang
fc5055c766 update App.js 2024-08-17 16:20:41 +00:00
GuoRuqiang
27eb358497 重新修改了chat 2024-08-17 16:17:24 +00:00
GuoRuqiang
6810ee0a28 Update Chat
修改chat界面,配合nextChat等前端可以自动传入第一个已启用令牌,
2024-08-17 23:09:45 +08:00
CalciumIon
7c4d9d225e feat: support SiliconFlow (close #437, close #403) 2024-08-16 18:27:26 +08:00
CalciumIon
d0f76a5c61 feat: support gpt-4o-gizmo-* (close #436) 2024-08-16 17:25:03 +08:00
CalciumIon
a5ec11e463 fix: add email missing Message-ID 2024-08-16 16:16:38 +08:00
CalciumIon
b3d8e3e9ae fix: lobechat #430 2024-08-16 14:59:32 +08:00
wozulong
fefe5913e9 merge upstream
Signed-off-by: wozulong <>
2024-08-15 22:30:28 +08:00
CalciumIon
0c46d0c7af chore: remove useless code 2024-08-14 22:44:33 +08:00
CalciumIon
8cd8cc29bc fix: log page 'Cannot read properties of undefined (reading 'length')' 2024-08-14 22:43:57 +08:00
CalciumIon
748e34fd10 feat: update openai models list 2024-08-14 15:51:48 +08:00
CalciumIon
f9392ca904 feat: 避免暴露内部错误 2024-08-14 15:49:33 +08:00
CalciumIon
1988c41842 feat: update chatgpt-4o-latest model ratio 2024-08-14 15:47:08 +08:00
CalciumIon
6cb0eb4b39 feat: update claude tools calling 2024-08-13 17:54:24 +08:00
Calcium-Ion
59d06a5576 Merge pull request #427 from QuentinHsu/fix-log-pagination
fix log pagination
2024-08-13 17:50:12 +08:00
Calcium-Ion
1b900e3917 Merge pull request #426 from OswinWu/fix-log-page
Fix log page
2024-08-13 17:50:03 +08:00
Calcium-Ion
accbae3904 Merge pull request #432 from xixingya/feat-add-logdb
Feature: Support Log DB
2024-08-13 17:48:25 +08:00
liuzhifei
d82bd20354 support log db 2024-08-13 10:29:55 +08:00
liuzhifei
0c01f49bc5 add log db 2024-08-13 10:28:35 +08:00
QuentinHsu
9edb7c4ade fix: log pagination 2024-08-11 11:25:32 +08:00
Nothing.
228104e848 Merge branch 'Calcium-Ion:main' into fix-log-page 2024-08-11 11:22:34 +08:00
OswinWu
a2af637e7f fix: log分页问题 2024-08-11 11:21:34 +08:00
QuentinHsu
d6f6403fd3 chore: update @so1ve/prettier-config to version 3.1.0 2024-08-11 11:18:08 +08:00
CalciumIon
4b5303a77b feat: 区分额度不足和预扣费失败提示 2024-08-09 18:48:13 +08:00
CalciumIon
6eab0cc370 feat: 区分额度不足和预扣费失败提示 2024-08-09 18:34:51 +08:00
CalciumIon
9e45dbe964 fix: close #422 2024-08-09 16:14:05 +08:00
Calcium-Ion
e495354823 Merge pull request #425 from dalefengs/fix_group
fix: 渠道多分组查询 sqlite 查询兼容
2024-08-09 15:44:23 +08:00
FENG
9452be51b9 fix: sqlite group 查询兼容 2024-08-09 11:39:19 +08:00
Calcium-Ion
43076c2f33 Merge pull request #415 from dalefengs/fix_group
fix: 渠道多分组,优化分组 like 查询
2024-08-08 20:47:51 +08:00
CalciumIon
04f0084d97 fix: 修复mysql兼容问题 2024-08-08 20:45:41 +08:00
CalciumIon
2e3c266bd6 fix: response format 2024-08-07 15:43:01 +08:00
wozulong
142d5fb209 merge upstream
Signed-off-by: wozulong <>
2024-08-07 14:41:33 +08:00
wozulong
f6ccd402e2 support gpt-4o-2024-08-06
Signed-off-by: wozulong <>
2024-08-07 14:40:57 +08:00
CalciumIon
4490258104 fix bug 2024-08-07 02:50:22 +08:00
CalciumIon
93c6d765c7 feat: support gpt-4o-2024-08-06 2024-08-07 02:49:02 +08:00
FENG
e614ca370a fix: optionList bug 2024-08-06 21:30:20 +08:00
wozulong
1c371300ab merge upstream
Signed-off-by: wozulong <>
2024-08-06 15:54:28 +08:00
FENG
c152b4de08 chore: indent recovery 2024-08-06 15:40:44 +08:00
FENG
190316f66e fix: 渠道多分组,优化分组 like 查询 2024-08-05 22:35:16 +08:00
CalciumIon
67878731fc feat: log user id 2024-08-04 14:35:16 +08:00
CalciumIon
a0a3807bd4 chore: epay 2024-08-04 03:12:24 +08:00
CalciumIon
5d0d268c97 fix: epay 2024-08-04 00:18:32 +08:00
CalciumIon
0b4ef42d86 fix: channel typ error 2024-08-03 22:41:47 +08:00
CalciumIon
0123ad4d61 fix: 重试后request id不一致 2024-08-03 17:46:13 +08:00
CalciumIon
5acf074541 chore: 优化自动禁用代码 2024-08-03 17:32:28 +08:00
Calcium-Ion
8af0d9f22f Merge pull request #409 from utopeadia/main
修改readme错误
2024-08-03 17:31:08 +08:00
HowieWu
afd328efcf 修改readme错误 2024-08-03 17:19:44 +08:00
CalciumIon
dd12a0052f chore: 优化relay代码 2024-08-03 17:12:16 +08:00
CalciumIon
fbe6cd75b1 chore: 优化relay代码 2024-08-03 17:07:14 +08:00
CalciumIon
8a9ff36fbf chore: 优化relay代码 2024-08-03 16:55:29 +08:00
CalciumIon
88ba8a840e feat: 优化充值订单号 2024-08-03 01:28:18 +08:00
CalciumIon
e504665f68 feat: 优化Gemini模型版本获取逻辑 2024-08-02 17:23:59 +08:00
Calcium-Ion
54657ec27b Merge pull request #405 from utopeadia/main
Modify the GEMINI version acquisition logic and add support for more gemini1.5pro/flash interfaces
2024-08-02 17:13:46 +08:00
Calcium-Ion
ae6b4e0be2 Merge pull request #399 from kakingone/main
add-mjp-discord-upload
2024-08-02 17:10:35 +08:00
HowieWu
fc0db4505c Update README.md
增加Gemini版本变量说明
2024-08-02 11:25:41 +08:00
HowieWu
22a98c5879 修改Gemini版本获取逻辑
使用GEMINI_MODEL_API环境变量覆盖默认版本映射,使用","分隔不同模型和版本
-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta"
2024-08-02 11:20:26 +08:00
CalciumIon
f8f15bd1d0 fix: rpm模糊查询 2024-08-01 18:14:10 +08:00
CalciumIon
b7690fe17d fix: 日志模糊查询 2024-08-01 18:06:25 +08:00
CalciumIon
58b4c237a4 feat: 优化rpm查询 2024-08-01 17:39:18 +08:00
CalciumIon
54f6e660f1 feat: 优化日志查初始时间 2024-08-01 17:36:26 +08:00
CalciumIon
3b1745c712 feat: 优化日志查询条件 2024-08-01 16:33:59 +08:00
CalciumIon
c92ab3b569 feat: 日志新增rpm和tpm数据。(close #384) 2024-08-01 16:13:08 +08:00
wozulong
918690701d merge upstream
Signed-off-by: wozulong <>
2024-08-01 15:12:54 +08:00
CalciumIon
1501ccb919 fix: error channel name on notify. #338 2024-07-31 18:20:13 +08:00
Calcium-Ion
7f2a2a7de0 Merge pull request #400 from OswinWu/feat-gitignore-web-dist
feat: ignore npm build dir
2024-07-31 17:14:50 +08:00
Calcium-Ion
cce7d0258f Merge pull request #401 from HynoR/main
Support cloudflare llama3.1-8b
2024-07-31 17:14:30 +08:00
TAKO
c5e8d7ec20 Support cloudflare llama3.1-8b 2024-07-31 17:11:25 +08:00
OswinWu
fe16d51fe4 feat: ignore npm build dir 2024-07-31 16:50:19 +08:00
kakingone
2100d8ee0c addupload 2024-07-31 15:48:51 +08:00
CalciumIon
fbce36238e feat: support dify agent 2024-07-30 17:30:40 +08:00
CalciumIon
a6b6bcfe00 chore: remove useless code 2024-07-28 01:12:26 +08:00
CalciumIon
07e55cc999 chore: update token page 2024-07-28 00:05:53 +08:00
CalciumIon
b16e6bf423 fix: panic when get model ratio (close #392) 2024-07-27 18:09:09 +08:00
CalciumIon
b7bc205b73 feat: print user id when error 2024-07-27 17:55:36 +08:00
CalciumIon
88cc88c5d0 feat: support ollama tools 2024-07-27 17:51:05 +08:00
CalciumIon
ab1d61d910 feat: print user id when error 2024-07-27 17:47:30 +08:00
Calcium-Ion
d4a5df7373 Merge pull request #391 from OswinWu/fix-outlook-smtp
[fix] fix send email error using outlook smtp
2024-07-26 20:24:08 +08:00
CalciumIon
9e610c9429 fix: image quota (close #382) 2024-07-26 18:51:34 +08:00
Oswin
da490db6d3 [fix] fix send email error using outlook smtp 2024-07-26 17:47:36 +08:00
1808837298@qq.com
b8291dcd13 fix: gemini 2024-07-23 18:34:16 +08:00
Calcium-Ion
b0d9756c14 Merge pull request #380 from crabkun/main
fix: 修复aws claude渠道panic的问题
2024-07-23 18:22:27 +08:00
Calcium-Ion
9dc07a8585 Merge pull request #383 from Yan-Zero/main
fix: the base64 format image_url for gemini
2024-07-23 18:22:06 +08:00
1808837298@qq.com
caaecb8d54 fix: first login error (close #385) 2024-07-23 18:25:43 +08:00
Yan Tau
b9454c3f14 fix: the base64 format image_url for gemini 2024-07-22 21:20:23 +08:00
crabkun
96bdf97194 fix: 修复aws claude渠道panic的问题 2024-07-21 01:27:29 +08:00
130 changed files with 8763 additions and 3151 deletions

3
.gitignore vendored
View File

@@ -5,4 +5,5 @@ upload
*.db
build
*.db-journal
logs
logs
web/dist

View File

@@ -1,6 +1,13 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
@@ -54,17 +61,20 @@
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
@@ -113,22 +123,14 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
## Suno接口设置文档
[对接文档](Suno.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目

View File

@@ -129,6 +129,9 @@ var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
@@ -140,6 +143,10 @@ const (
RoleRootUser = 100
)
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
@@ -236,6 +243,8 @@ const (
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeDummy // this one is only for count, do not add any channel after this
@@ -282,4 +291,6 @@ var ChannelBaseURLs = []string{
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
}

View File

@@ -0,0 +1,40 @@
package common
import (
"errors"
"net/smtp"
"strings"
)
type outlookAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &outlookAuth{username, password}
}
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown fromServer")
}
}
return nil, nil
}
func isOutlookServer(server string) bool {
// 兼容多地区的outlook邮箱和ofb邮箱
// 其实应该加一个Option来区分是否用LOGIN的方式登录
// 先临时兼容一下
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}

View File

@@ -9,17 +9,26 @@ import (
"time"
)
func generateMessageID() string {
domain := strings.Split(SMTPAccount, "@")[1]
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
}
func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount
}
if SMTPServer == "" && SMTPAccount == "" {
return fmt.Errorf("SMTP 服务器未配置")
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Date: %s\r\n"+
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), generateMessageID(), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
@@ -62,6 +71,9 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil {
return err
}
} else if isOutlookServer(SMTPAccount) {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
}

View File

@@ -3,6 +3,7 @@ package common
import (
"encoding/json"
"strings"
"sync"
)
// from songquanpeng/one-api
@@ -30,10 +31,16 @@ var defaultModelRatio = map[string]float64{
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075,
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5,
"o1-mini-2024-09-12": 1.5,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
@@ -73,13 +80,13 @@ var defaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -101,8 +108,10 @@ var defaultModelRatio = map[string]float64{
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
@@ -114,6 +123,13 @@ var defaultModelRatio = map[string]float64{
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
@@ -132,26 +148,28 @@ var defaultModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus ": 1.5,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
"yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus": 1.5,
"command-r-08-2024": 0.075,
"command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
@@ -181,10 +199,17 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var modelPrice map[string]float64 = nil
var modelRatio map[string]float64 = nil
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{
@@ -194,11 +219,18 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4o-all": 2,
}
func ModelPrice2JSONString() string {
if modelPrice == nil {
modelPrice = defaultModelPrice
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
jsonBytes, err := json.Marshal(modelPrice)
return modelPriceMap
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
@@ -206,21 +238,21 @@ func ModelPrice2JSONString() string {
}
func UpdateModelPriceByJSONString(jsonStr string) error {
modelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPrice)
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
}
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil {
modelPrice = defaultModelPrice
}
GetModelPriceMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
price, ok := modelPrice[name]
price, ok := modelPriceMap[name]
if !ok {
if printErr {
SysError("model price not found: " + name)
@@ -230,18 +262,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
func GetModelPriceMap() map[string]float64 {
if modelPrice == nil {
modelPrice = defaultModelPrice
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelPrice
return modelRatioMap
}
func ModelRatio2JSONString() string {
if modelRatio == nil {
modelRatio = defaultModelRatio
}
jsonBytes, err := json.Marshal(modelRatio)
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
@@ -249,20 +281,20 @@ func ModelRatio2JSONString() string {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatio)
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
}
func GetModelRatio(name string) float64 {
if modelRatio == nil {
modelRatio = defaultModelRatio
}
GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
ratio, ok := modelRatio[name]
ratio, ok := modelRatioMap[name]
if !ok {
SysError("model ratio not found: " + name)
return 30
@@ -318,15 +350,22 @@ func GetCompletionRatio(name string) float64 {
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasPrefix(name, "gpt-4o-mini") {
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || "gpt-4o-2024-05-13" == name {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
return 4
}
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
return 3
}
return 2
}
if "o1" == name || strings.HasPrefix(name, "o1-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3
} else if strings.HasPrefix(name, "claude-2") {
@@ -338,7 +377,7 @@ func GetCompletionRatio(name string) float64 {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
return 4
}
if strings.HasPrefix(name, "command") {
switch name {
@@ -346,6 +385,10 @@ func GetCompletionRatio(name string) float64 {
return 3
case "command-r-plus":
return 5
case "command-r-08-2024":
return 4
case "command-r-plus-08-2024":
return 4
default:
return 2
}

View File

@@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes)
}
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
@@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
return m
}
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {

28
common/user_groups.go Normal file
View File

@@ -0,0 +1,28 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@@ -3,11 +3,14 @@ package common
import (
"context"
"errors"
crand "crypto/rand"
"encoding/base64"
"fmt"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template"
"log"
"math/big"
"math/rand"
"net"
"net/http"
@@ -133,6 +136,11 @@ func IntMax(a int, b int) int {
}
}
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
@@ -142,24 +150,35 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
rand.New(rand.NewSource(time.Now().UnixNano()))
}
func GenerateKey() string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
func GenerateRandomCharsKey(length int) (string, error) {
b := make([]byte, length)
maxI := big.NewInt(int64(len(keyChars)))
for i := range b {
n, err := crand.Int(crand.Reader, maxI)
if err != nil {
return "", err
}
key[i+16] = c
b[i] = keyChars[n.Int64()]
}
return string(key)
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
//rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48)
}
func GetRandomInt(max int) int {

35
constant/chat.go Normal file
View File

@@ -0,0 +1,35 @@
package constant
import (
"encoding/json"
"one-api/common"
)
var Chats = []map[string]string{
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},
{
"OpenCat": "opencat://team/join?domain={address}&token={key}",
},
}
func UpdateChatsByJsonString(jsonString string) error {
Chats = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &Chats)
}
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
common.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
}

View File

@@ -1,7 +1,10 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@@ -15,3 +18,34 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
)
var MidjourneyModel2Action = map[string]string{
@@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
}

View File

@@ -20,6 +20,7 @@ import (
"one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
"sync"
"time"
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
request := buildTestRequest()
request.Model = testModel
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return nil, nil
}
func buildTestRequest() *dto.GeneralOpenAIRequest {
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later
MaxTokens: 1,
Stream: false,
Model: "", // this will be set later
Stream: false,
}
if "o1" == model || strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
}
content, _ := json.Marshal("hi")
testMessage := dto.Message{
Role: "user",
Content: content,
}
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
ban := false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
shouldBanChannel := false
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
shouldBanChannel = true
}
// disable channel
if ban && isChannelEnabled {
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}

View File

@@ -198,6 +198,28 @@ func AddChannel(c *gin.Context) {
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
@@ -297,6 +319,27 @@ func UpdateChannel(c *gin.Context) {
})
return
}
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -112,7 +112,9 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{
GitHubId: githubUser.Login,
}
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -121,6 +123,14 @@ func GitHubOAuth(c *gin.Context) {
})
return
}
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))

View File

@@ -17,3 +17,18 @@ func GetGroups(c *gin.Context) {
"data": groupNames,
})
}
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
if _, ok := common.UserUsableGroups[groupName]; ok {
usableGroups[groupName] = common.UserUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

View File

@@ -132,6 +132,14 @@ func LinuxDoOAuth(c *gin.Context) {
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {

View File

@@ -1,18 +1,19 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
@@ -24,7 +25,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -35,17 +36,20 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"total": total,
"data": logs,
"data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
@@ -59,7 +63,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -70,8 +74,12 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"total": total,
"data": logs,
"data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -184,7 +192,7 @@ func DeleteHistoryLogs(c *gin.Context) {
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -65,6 +65,7 @@ func GetStatus(c *gin.Context) {
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
},
})
return

View File

@@ -137,31 +137,63 @@ func init() {
}
func ListModels(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission()
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
tokenModelLimit = map[string]bool{}
}
for allowModel, _ := range tokenModelLimit {
if _, ok := openAIModelsMap[allowModel]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
}
}
c.JSON(200, gin.H{

View File

@@ -7,18 +7,11 @@ import (
)
func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
pricing := model.GetPricing()
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": groupRatio,
"group_ratio": common.GroupRatio,
})
}

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -37,45 +38,89 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err
}
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, channelType, openaiErr)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channelId, channel.Type, openaiErr)
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
@@ -89,10 +134,48 @@ func Relay(c *gin.Context) {
}
}
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
if openaiErr.LocalError {
return false
}
if retryTimes <= 0 {
return false
}
@@ -113,26 +196,27 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if openaiErr.LocalError {
return false
}
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
@@ -208,14 +292,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
@@ -225,7 +309,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {

View File

@@ -2,8 +2,8 @@ package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/webhook"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/webhook"
"io"
"log"
"net/http"
@@ -23,7 +23,9 @@ func StripeWebhook(c *gin.Context) {
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"one-api/common"
"one-api/model"
"sort"
@@ -48,6 +49,13 @@ func TelegramBind(c *gin.Context) {
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId
if err := user.Update(false); err != nil {
c.JSON(200, gin.H{

View File

@@ -123,10 +123,19 @@ func AddToken(c *gin.Context) {
})
return
}
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
@@ -134,6 +143,8 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
}
err = cleanToken.Insert()
if err != nil {
@@ -221,6 +232,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
}
err = cleanToken.Update()
if err != nil {

View File

@@ -4,8 +4,8 @@ import "C"
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/checkout/session"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/checkout/session"
"log"
"one-api/common"
"one-api/model"

View File

@@ -7,10 +7,12 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"one-api/constant"
)
type LoginRequest struct {
@@ -66,6 +68,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username)
session.Set("role", user.Role)
session.Set("status", user.Status)
session.Set("group", user.Group)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save()
if err != nil {
@@ -158,8 +161,9 @@ func Register(c *gin.Context) {
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"message": "数据库错误,请稍后重试",
})
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return
}
if exist {
@@ -187,6 +191,48 @@ func Register(c *gin.Context) {
})
return
}
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -277,7 +323,18 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
user.AccessToken = common.GetUUID()
// get rand int 28-32
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -597,6 +654,7 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) {
var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -644,8 +702,8 @@ func CreateUser(c *gin.Context) {
}
type ManageRequest struct {
Username string `json:"username"`
Action string `json:"action"`
Id int `json:"id"`
Action string `json:"action"`
}
// ManageUser Only admin user can do this
@@ -661,7 +719,7 @@ func ManageUser(c *gin.Context) {
return
}
user := model.User{
Username: req.Username,
Id: req.Id,
}
// Fill attributes
model.DB.Unscoped().Where(&user).First(&user)
@@ -806,11 +864,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
var lock = sync.Mutex{}
var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {

View File

@@ -78,6 +78,13 @@ func WeChatAuth(c *gin.Context) {
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse

6
dto/playground.go Normal file
View File

@@ -0,0 +1,6 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

View File

@@ -1,14 +1,17 @@
package dto
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
type RerankResponseDocument struct {
Document any `json:"document"`
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}

View File

@@ -2,40 +2,39 @@ package dto
import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
BestOf int `json:"best_of,omitempty"`
Echo bool `json:"echo,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Suffix string `json:"suffix,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
BestOf int `json:"best_of,omitempty"`
Echo bool `json:"echo,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Suffix string `json:"suffix,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ParallelToolCalls bool `json:"parallel_Tool_Calls,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
}
type OpenAITools struct {

View File

@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
@@ -41,9 +42,9 @@ type OpenAITextResponse struct {
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
Object string `json:"object"`
Index int `json:"index"`
Embedding any `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {

24
go.mod
View File

@@ -1,13 +1,16 @@
module one-api
// +heroku goVersion go1.18
go 1.18
go 1.21
toolchain go1.22.4
require (
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -23,9 +26,10 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stripe/stripe-go/v76 v76.21.0
golang.org/x/crypto v0.21.0
github.com/stripe/stripe-go/v79 v79.12.0
golang.org/x/crypto v0.26.0
golang.org/x/image v0.15.0
golang.org/x/net v0.28.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -38,9 +42,8 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
@@ -51,6 +54,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -68,6 +72,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -75,10 +80,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

52
go.sum
View File

@@ -21,8 +21,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -32,11 +32,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
@@ -57,6 +56,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -81,10 +81,9 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -128,8 +127,6 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8=
github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -144,16 +141,17 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -176,9 +174,10 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs=
github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -197,19 +196,19 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -222,26 +221,27 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

View File

@@ -42,6 +42,11 @@ func main() {
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
@@ -55,6 +60,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {

View File

@@ -10,6 +10,17 @@ import (
"strings"
)
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c)
username := session.Get("username")
@@ -31,6 +42,14 @@ func authHelper(c *gin.Context, minRole int) {
}
user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid
username = user.Username
role = user.Role
@@ -101,9 +120,19 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort()
return
}
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username)
c.Set("role", role)
c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next()
}
@@ -153,6 +182,12 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
}
token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.UserId)
}
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
@@ -188,6 +223,8 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])

View File

@@ -22,6 +22,14 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
@@ -31,6 +39,20 @@ func Distribute() func(c *gin.Context) {
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.UserUsableGroups[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
@@ -184,19 +206,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
@@ -205,6 +221,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:

View File

@@ -1,11 +1,13 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -13,7 +15,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {

View File

@@ -36,6 +36,12 @@ func GetEnabledModels() []string {
return models
}
func GetAllEnableAbilities() []Ability {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"

View File

@@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
}
@@ -99,16 +106,23 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
// 构造WHERE子句
var whereClause string
var args []interface{}
if group != "" {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
}
// 执行查询
err := baseQuery.Where(whereClause, args...).Find(&channels).Error
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
if err != nil {
return nil, err
}

View File

@@ -3,10 +3,12 @@ package model
import (
"context"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
type Log struct {
@@ -37,7 +39,7 @@ const (
)
func GetLogByKey(key string) (logs []*Log, err error) {
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
return logs, err
}
@@ -53,7 +55,7 @@ func RecordLog(userId int, logType int, content string) {
Type: logType,
Content: content,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
}
@@ -83,7 +85,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
IsStream: isStream,
Other: otherStr,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
}
@@ -97,12 +99,12 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
tx = LOG_DB
} else {
tx = DB.Where("type = ?", logType)
tx = LOG_DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", "%"+modelName+"%")
tx = tx.Where("model_name like ?", modelName)
}
if username != "" {
tx = tx.Where("username = ?", username)
@@ -119,25 +121,26 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
return logs, total, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
tx = LOG_DB.Where("user_id = ?", userId)
} else {
tx = DB.Where("user_id = ? and type = ?", userId, logType)
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
tx = tx.Where("model_name like ?", modelName)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
@@ -148,14 +151,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, total, err
for i := range logs {
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
@@ -169,12 +169,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
@@ -185,12 +185,18 @@ type Stat struct {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" {
tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
@@ -199,17 +205,29 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
tx = tx.Where("model_name like ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// 只统计最近60秒的rpm和tpm
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -229,7 +247,25 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
return token
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
var total int64 = 0
for {
if nil != ctx.Err() {
return total, ctx.Err()
}
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
if nil != result.Error {
return total, result.Error
}
total += result.RowsAffected
if result.RowsAffected < int64(limit) {
break
}
}
return total, nil
}

View File

@@ -15,6 +15,8 @@ import (
var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
@@ -30,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: common.GetUUID(),
AccessToken: nil,
Quota: 100000000,
}
DB.Create(&rootUser)
@@ -38,9 +40,9 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
func chooseDB(envName string) (*gorm.DB, error) {
dsn := os.Getenv(envName)
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
@@ -52,6 +54,13 @@ func chooseDB() (*gorm.DB, error) {
PrepareStmt: true, // precompile SQL
})
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL
common.SysLog("using MySQL as database")
// check parseTime
@@ -76,7 +85,7 @@ func chooseDB() (*gorm.DB, error) {
}
func InitDB() (err error) {
db, err := chooseDB()
db, err := chooseDB("SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -100,52 +109,7 @@ func InitDB() (err error) {
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return err
}
err = db.AutoMigrate(&User{})
if err != nil {
return err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = db.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
err = migrateDB()
return err
} else {
common.FatalLog(err)
@@ -153,8 +117,103 @@ func InitDB() (err error) {
return err
}
func CloseDB() error {
sqlDB, err := DB.DB()
func InitLogDB() (err error) {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
LOG_DB = db
sqlDB, err := LOG_DB.DB()
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateLOGDB()
return err
} else {
common.FatalLog(err)
}
return err
}
func migrateDB() error {
err := DB.AutoMigrate(&Channel{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return err
}
@@ -162,6 +221,16 @@ func CloseDB() error {
return err
}
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}
var (
lastPingTime time.Time
pingMutex sync.Mutex

View File

@@ -70,6 +70,7 @@ func InitOptionMap() {
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
@@ -90,6 +91,7 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -251,6 +253,8 @@ func updateOptionMap(key string, value string) (err error) {
common.ServerAddress = value
case "OutProxyUrl":
common.OutProxyUrl = value
case "Chats":
err = constant.UpdateChatsByJsonString(value)
case "StripeApiSecret":
common.StripeApiSecret = value
case "StripeWebhookSecret":
@@ -315,6 +319,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":

View File

@@ -7,14 +7,13 @@ import (
)
type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
EnableGroup []string `json:"enable_groups,omitempty"`
}
var (
@@ -23,40 +22,47 @@ var (
updatePricingLock sync.Mutex
)
func GetPricing(group string) []Pricing {
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
if group != "" {
userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
pricing.Available = false
}
userPricingMap = append(userPricingMap, pricing)
}
return userPricingMap
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enabledModels := GetEnabledModels()
allModels := make(map[string]int)
for i, model := range enabledModels {
allModels[model] = i
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
}
pricingMap = make([]Pricing, 0)
for model, _ := range allModels {
for model, groups := range modelGroupsMap {
pricing := Pricing{
Available: true,
ModelName: model,
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"strconv"
"strings"
)
@@ -22,10 +23,34 @@ type Token struct {
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
var tokens []*Token
var err error
@@ -50,12 +75,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
return token, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
return token, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
@@ -65,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
return token, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
@@ -78,7 +103,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
}
return token, nil
}
@@ -129,7 +154,8 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err
}
@@ -231,51 +257,56 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
if quota < 0 {
return 0, errors.New("quota 不能为负数!")
}
token, err := GetTokenById(tokenId)
if err != nil {
return 0, err
if !relayInfo.IsPlayground {
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
userQuota, err = GetUserQuota(relayInfo.UserId)
if err != nil {
return 0, err
}
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
if err != nil {
return 0, err
}
}
err = DecreaseUserQuota(token.UserId, quota)
err = DecreaseUserQuota(relayInfo.UserId, quota)
return userQuota - quota, err
}
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
err = DecreaseUserQuota(relayInfo.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
err = IncreaseUserQuota(relayInfo.UserId, -quota)
}
if err != nil {
return err
}
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
@@ -284,7 +315,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(token.UserId)
email, err := GetUserEmail(relayInfo.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}

View File

@@ -27,7 +27,7 @@ type User struct {
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@@ -41,6 +41,17 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (user *User) GetAccessToken() string {
if user.AccessToken == nil {
return ""
}
return *user.AccessToken
}
func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -218,7 +229,7 @@ func (user *User) Insert(inviterId int) error {
}
}
user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID()
//user.SetAccessToken(common.GetUUID())
user.AffCode = common.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
@@ -312,11 +323,12 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
password := user.Password
if user.Username == "" || password == "" {
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
return errors.New("用户名或密码为空")
}
// find buy username or email
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
DB.Where("username = ? OR email = ?", username, username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
@@ -364,14 +376,6 @@ func (user *User) FillUserByWeChatId() error {
return nil
}
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
}
DB.Where(User{Username: user.Username}).First(user)
return nil
}
func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" {
return errors.New("Telegram id 为空!")
@@ -384,27 +388,27 @@ func (user *User) FillUserByTelegramId() error {
}
func IsEmailAlreadyTaken(email string) bool {
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
}
func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
}
func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
}
func ResetUserPasswordByEmail(email string, password string) error {

View File

@@ -36,8 +36,8 @@ type AliEmbeddingRequest struct {
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
Embedding any `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {

View File

@@ -8,7 +8,6 @@ import (
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
)
const (
@@ -31,11 +30,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
a.RequestMode = RequestModeMessage
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -53,11 +48,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest
var err error
if a.RequestMode == RequestModeCompletion {
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
} else {
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
}
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err

View File

@@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
if resp != nil {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
}

View File

@@ -50,9 +50,9 @@ type BaiduEmbeddingRequest struct {
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
Object string `json:"object"`
Embedding any `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {

View File

@@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
}
return
}

View File

@@ -31,9 +31,9 @@ type ClaudeMessage struct {
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"input_schema"`
}
type InputSchema struct {

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
func stopReasonClaude2OpenAI(reason string) string {
@@ -63,15 +64,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
claudeTool := Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
claudeTool.InputSchema = make(map[string]interface{})
claudeTool.InputSchema["type"] = params["type"].(string)
claudeTool.InputSchema["properties"] = params["properties"]
claudeTool.InputSchema["required"] = params["required"]
for s, a := range params {
if s == "type" || s == "properties" || s == "required" {
continue
}
claudeTool.InputSchema[s] = a
}
claudeTools = append(claudeTools, claudeTool)
}
}
@@ -102,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
lastMessage := dto.Message{
Role: "tool",
}
for i, message := range textRequest.Messages {
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" {
textRequest.Messages[i].Role = "user"
}
@@ -116,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Role: message.Role,
Content: message.Content,
}
if lastMessage != nil && lastMessage.Role == message.Role {
if message.Role == "tool" {
fmtMessage.ToolCallId = message.ToolCallId
}
if message.Role == "assistant" && message.ToolCalls != nil {
fmtMessage.ToolCalls = message.ToolCalls
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content
@@ -129,10 +139,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
fmtMessage.Content = content
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = &textRequest.Messages[i]
lastMessage = fmtMessage
}
claudeMessages := make([]ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages {
if message.Role == "system" {
if message.IsStringContent() {
@@ -148,10 +159,54 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeRequest.System = content
}
} else {
if isFirstMessage {
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
{
Type: "text",
Text: "...",
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{
Role: message.Role,
}
if message.IsStringContent() {
if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
{
Type: "text",
Text: content,
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -184,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
Input: inputObj,
})
}
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
@@ -318,12 +395,13 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 {
choice.Message.ToolCalls = tools
}
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
@@ -405,7 +483,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, usage
}
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -431,15 +509,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = promptTokens + completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens

View File

@@ -1,6 +1,7 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",

View File

@@ -1,7 +1,10 @@
package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"command-r", "command-r-plus",
"command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
"command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}

View File

@@ -8,6 +8,7 @@ type CohereRequest struct {
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
}
type ChatHistory struct {

View File

@@ -23,6 +23,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(),
}
if common.CohereSafetySetting != "NONE" {
cohereReq.SafetyMode = common.CohereSafetySetting
}
if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000
}
@@ -44,6 +47,7 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
})
}
}
return &cohereReq
}

View File

@@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
@@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
// 定义一个映射,存储模型名称和对应的版本
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
@@ -75,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = geminiChatStreamHandler(c, resp, info)
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = GeminiChatHandler(c, resp)
}
return
}

View File

@@ -6,7 +6,7 @@ const (
var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
}
var ChannelName = "google gemini"

View File

@@ -83,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
continue
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
content.Parts = parts
@@ -205,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
@@ -264,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
return nil, usage
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -58,6 +58,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
}
return
}

View File

@@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.OpenAIEmbeddingResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -3,24 +3,39 @@ package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"`
Model string `json:"model,omitempty"`
Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
}
type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"`
Embedding any `json:"embedding,omitempty"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
}
//type OllamaOptions struct {
//}

View File

@@ -9,7 +9,6 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -28,21 +27,32 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
}
}
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
Model: request.Model,
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
}
@@ -60,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding,

View File

@@ -78,6 +78,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
if "o1" == request.Model || strings.HasPrefix(request.Model, "o1-") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
}
return request, nil
}

View File

@@ -8,8 +8,11 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"gpt-4o", "gpt-4o-2024-05-13",
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable",

View File

@@ -22,7 +22,7 @@ import (
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
responseId := ""
var responseId string
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
@@ -86,7 +86,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
if err == nil {
if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) {
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
containStreamUsage = true
usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
@@ -109,14 +115,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -133,14 +134,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {

View File

@@ -0,0 +1,83 @@
package siliconflow
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,51 @@
package siliconflow
var ModelList = []string{
"THUDM/glm-4-9b-chat",
//"stabilityai/stable-diffusion-xl-base-1.0",
//"TencentARC/PhotoMaker",
"InstantX/InstantID",
//"stabilityai/stable-diffusion-2-1",
//"stabilityai/sd-turbo",
//"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
//"stabilityai/stable-diffusion-3-medium",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
"Qwen/Qwen2-Math-72B-Instruct",
"netease-youdao/bce-reranker-base_v1",
"BAAI/bge-reranker-v2-m3",
}
var ChannelName = "siliconflow"

View File

@@ -0,0 +1,17 @@
package siliconflow
import "one-api/dto"
type SFTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type SFMeta struct {
Tokens SFTokens `json:"tokens"`
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -0,0 +1,44 @@
package siliconflow
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
}
rerankResp := &dto.RerankResponse{
Results: siliconflowResp.Results,
Usage: *usage,
}
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -0,0 +1,184 @@
package vertex
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
)
const (
RequestModeClaude = 1
RequestModeGemini = 2
RequestModeLlama = 3
)
var claudeModelMap = map[string]string{
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
}
const anthropicVersion = "vertex-2023-10-16"
type Adaptor struct {
RequestMode int
AccountCredentials Credentials
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &Credentials{}
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
info.UpstreamModelName = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
vertexClaudeReq := &VertexAIClaudeRequest{
AnthropicVersion: anthropicVersion,
}
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request")
}
c.Set("request_model", request.Model)
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest := gemini.CovertGemini2OpenAI(*request)
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeLlama {
return request, nil
}
return nil, errors.New("unsupported request mode")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
case RequestModeLlama:
err, usage = openai.OaiStreamHandler(c, resp, info)
}
} else {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp)
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,15 @@
package vertex
var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"meta/llama3-405b-instruct-maas",
}
var ChannelName = "vertex-ai"

View File

@@ -0,0 +1,17 @@
package vertex
import "one-api/relay/channel/claude"
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@@ -0,0 +1,16 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@@ -0,0 +1,122 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"strings"
"fmt"
"time"
)
type Credentials struct {
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
}
var Cache = asynccache.NewAsyncCache(asynccache.Options{
RefreshDuration: time.Minute * 35,
EnableExpire: true,
ExpireDuration: time.Minute * 30,
Fetcher: func(key string) (interface{}, error) {
return nil, errors.New("not found")
},
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil
}
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
newToken, err := exchangeJwtForAccessToken(signedJWT)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
if err := Cache.SetDefault(cacheKey, newToken); err {
return newToken, nil
}
return newToken, nil
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
resp, err := http.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if accessToken, ok := result["access_token"].(string); ok {
return accessToken, nil
}
return "", fmt.Errorf("failed to get access token: %v", result)
}

View File

@@ -1,7 +1,7 @@
package zhipu_4v
var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus",
}
var ChannelName = "zhipu_4v"

View File

@@ -20,8 +20,10 @@ type RelayInfo struct {
setFirstResponse bool
ApiType int
IsStream bool
IsPlayground bool
RelayMode int
UpstreamModelName string
OriginModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
@@ -33,7 +35,7 @@ type RelayInfo struct {
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
@@ -57,17 +59,27 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
}
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
@@ -112,7 +124,7 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
@@ -140,3 +152,20 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
}
return info
}
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -23,6 +23,8 @@ const (
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -66,6 +68,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
case common.ChannelTypeSiliconFlow:
apiType = APITypeSiliconFlow
case common.ChannelTypeVertexAi:
apiType = APITypeVertexAi
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -27,6 +27,7 @@ const (
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@@ -41,7 +42,7 @@ const (
func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") {
if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions
@@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
// midjourney plus
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") {

View File

@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
@@ -87,7 +87,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -126,7 +126,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -136,7 +136,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -38,9 +38,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
@@ -50,6 +48,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
//if imageRequest.N != 1 {
// return nil, errors.New("n must be 1")
//}
@@ -121,10 +122,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
}
adaptor := GetAdaptor(relayInfo.ApiType)
@@ -180,7 +182,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
return nil
}

View File

@@ -12,6 +12,7 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
@@ -146,6 +147,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
@@ -191,7 +193,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -356,6 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
consumeQuota := true
var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -382,6 +385,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionBlend
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
@@ -493,7 +498,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -547,7 +552,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}
@@ -580,7 +585,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
}
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
midjourneyTask.Progress = "100%"
midjourneyTask.Status = "SUCCESS"
}
err = midjourneyTask.Insert()
if err != nil {
return &dto.MidjourneyResponse{
@@ -594,7 +602,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))

View File

@@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
}
case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations:
if textRequest.Input == "" {
if textRequest.Input == "" || textRequest.Input == nil {
return nil, errors.New("field input is required")
}
case relayconstant.RelayModeEdits:
@@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -188,7 +188,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
@@ -247,9 +247,12 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
if userQuota <= 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
@@ -262,17 +265,17 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -280,11 +283,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return preConsumedQuota, userQuota, nil
}
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
@@ -295,7 +298,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += " ,(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
@@ -316,7 +326,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
totalTokens := promptTokens + completionTokens
var logContent string
if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f分组倍率 %.2f补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
logContent = fmt.Sprintf("模型倍率 %.2f补全倍率 %.2f分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -335,7 +345,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}

View File

@@ -16,8 +16,10 @@ import (
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
@@ -62,6 +64,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &jina.Adaptor{}
case constant.APITypeCloudflare:
return &cloudflare.Adaptor{}
case constant.APITypeSiliconFlow:
return &siliconflow.Adaptor{}
case constant.APITypeVertexAi:
return &vertex.Adaptor{}
}
return nil
}

View File

@@ -38,6 +38,23 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if len(rerankRequest.Documents) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
@@ -84,7 +101,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
}
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -94,7 +111,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -111,7 +111,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
defer func(ctx context.Context) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}

View File

@@ -41,6 +41,7 @@ func SetApiRouter(router *gin.Engine) {
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/groups", controller.GetUserGroups)
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth())

View File

@@ -16,6 +16,11 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel)
}
playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
@@ -79,5 +84,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
}
}

View File

@@ -54,6 +54,8 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
switch err.Error.Type {
case "insufficient_quota":
return true
case "insufficient_user_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
@@ -71,6 +73,15 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true
}
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
return true
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
return true
}
return false
}

View File

@@ -28,13 +28,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
// 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
openAIError := dto.OpenAIError{
Message: text,
Type: "new_api_error",
@@ -113,14 +111,12 @@ func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskErro
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
text := err.Error()
// 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
taskError := &dto.TaskError{
Code: code,
Message: text,

View File

@@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneyUpload:
action = constant.MjActionUpload
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
@@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
log.Printf("respStr: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
if err2 != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
}
//log.Printf("midjResponse: %v", midjResponse)

View File

@@ -40,6 +40,10 @@ func InitTokenEncoders() {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "chatgpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if "o1" == model || strings.HasPrefix(model, "o1") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {

View File

@@ -4,8 +4,8 @@
"private": true,
"type": "module",
"dependencies": {
"@douyinfe/semi-icons": "^2.46.1",
"@douyinfe/semi-ui": "^2.55.3",
"@douyinfe/semi-icons": "^2.63.1",
"@douyinfe/semi-ui": "^2.63.1",
"@visactor/react-vchart": "~1.8.8",
"@visactor/vchart": "~1.8.8",
"@visactor/vchart-semi-theme": "~1.8.8",
@@ -22,7 +22,8 @@
"react-toastify": "^9.0.8",
"react-turnstile": "^1.0.5",
"semantic-ui-offline": "^2.5.0",
"semantic-ui-react": "^2.1.3"
"semantic-ui-react": "^2.1.3",
"sse": "github:mpetazzoni/sse.js"
},
"scripts": {
"dev": "vite",
@@ -50,7 +51,7 @@
]
},
"devDependencies": {
"@so1ve/prettier-config": "^2.0.0",
"@so1ve/prettier-config": "^3.1.0",
"@vitejs/plugin-react": "^4.2.1",
"prettier": "^3.0.0",
"typescript": "4.4.2",

4760
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -21,11 +21,12 @@ import Redemption from './pages/Redemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
import Chat2Link from './pages/Chat2Link';
import { Layout } from '@douyinfe/semi-ui';
import Midjourney from './pages/Midjourney';
import Pricing from './pages/Pricing/index.js';
import Task from './pages/Task/index.js';
// import Detail from './pages/Detail';
import Playground from './components/Playground.js';
const Home = lazy(() => import('./pages/Home'));
const Detail = lazy(() => import('./pages/Detail'));
@@ -59,215 +60,232 @@ function App() {
}, []);
return (
<Layout>
<Layout.Content>
<Routes>
<Route
path='/'
element={
<>
<Routes>
<Route
path='/'
element={
<Suspense fallback={<Loading></Loading>}>
<Home />
</Suspense>
}
/>
<Route
path='/channel'
element={
<PrivateRoute>
<Channel />
</PrivateRoute>
}
/>
<Route
path='/channel/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/channel/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/token'
element={
<PrivateRoute>
<Token />
</PrivateRoute>
}
/>
<Route
path='/playground'
element={
<PrivateRoute>
<Playground />
</PrivateRoute>
}
/>
<Route
path='/redemption'
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path='/user'
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path='/user/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/edit'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
</Suspense>
}
/>
<Route
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
<Suspense fallback={<Loading></Loading>}>
<LinuxDoOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Home />
<Setting />
</Suspense>
}
/>
<Route
path='/channel'
element={
<PrivateRoute>
<Channel />
</PrivateRoute>
}
/>
<Route
path='/channel/edit/:id'
element={
</PrivateRoute>
}
/>
<Route
path='/topup'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
<TopUp />
</Suspense>
}
/>
<Route
path='/channel/add'
element={
</PrivateRoute>
}
/>
<Route
path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
<Detail />
</Suspense>
}
/>
<Route
path='/token'
element={
<PrivateRoute>
<Token />
</PrivateRoute>
}
/>
<Route
path='/redemption'
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path='/user'
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path='/user/edit/:id'
element={
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<EditUser />
<Midjourney />
</Suspense>
}
/>
<Route
path='/user/edit'
element={
</PrivateRoute>
}
/>
<Route
path='/task'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<EditUser />
<Task />
</Suspense>
}
/>
<Route
path='/user/reset'
element={
</PrivateRoute>
}
/>
<Route
path='/pricing'
element={
<Suspense fallback={<Loading></Loading>}>
<Pricing />
</Suspense>
}
/>
<Route
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path='/chat/:id?'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
{/* 方便使用chat2link直接跳转聊天... */}
<Route
path='/chat2link'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
<Chat2Link />
</Suspense>
}
/>
<Route
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
<Suspense fallback={<Loading></Loading>}>
<LinuxDoOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Setting />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/topup'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Detail />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Midjourney />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/task'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Task />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/pricing'
element={
<Suspense fallback={<Loading></Loading>}>
<Pricing />
</Suspense>
}
/>
<Route
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path='/chat'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
<Route path='*' element={<NotFound />} />
</Routes>
</Layout.Content>
</Layout>
</PrivateRoute>
}
/>
<Route path='*' element={<NotFound />} />
</Routes>
</>
);
}

View File

@@ -749,7 +749,8 @@ const ChannelsTable = () => {
<Form.Select
field='group'
label='分组'
optionList={groupOptions}
optionList={[{ label: '选择分组', value: null }, ...groupOptions]}
initValue={null}
onChange={(v) => {
setSearchGroup(v);
searchChannels(searchKeyword, v, searchModel);

View File

@@ -3,7 +3,7 @@ import React, { useEffect, useState } from 'react';
import { getFooterHTML, getSystemName } from '../helpers';
import { Layout, Tooltip } from '@douyinfe/semi-ui';
const Footer = () => {
const FooterBar = () => {
const systemName = getSystemName();
const [footer, setFooter] = useState(getFooterHTML());
let remainCheckTimes = 5;
@@ -52,21 +52,17 @@ const Footer = () => {
}, []);
return (
<Layout>
<Layout.Content style={{ textAlign: 'center' }}>
{footer ? (
<Tooltip content={defaultFooter}>
<div
className='custom-footer'
dangerouslySetInnerHTML={{ __html: footer }}
></div>
</Tooltip>
) : (
defaultFooter
)}
</Layout.Content>
</Layout>
<div style={{ textAlign: 'center' }}>
{footer ? (
<div
className='custom-footer'
dangerouslySetInnerHTML={{ __html: footer }}
></div>
) : (
defaultFooter
)}
</div>
);
};
export default Footer;
export default FooterBar;

Some files were not shown because too many files have changed in this diff Show More