Compare commits

...

347 Commits

Author SHA1 Message Date
Calcium-Ion
300947f400 Merge pull request #197 from xqx333/main
Update model-ratio.go
2024-04-11 14:15:33 +08:00
xqx333
bf94893f6a Update model-ratio.go
修复gpt-4-1106-preview和gpt-4-0125-preview的输出倍率错误
2024-04-11 14:03:51 +08:00
1808837298@qq.com
4ef2422b97 update model-ratio 2024-04-10 20:12:56 +08:00
1808837298@qq.com
f188147680 feat: support gpt-4-turbo 2024-04-10 20:10:54 +08:00
Calcium-Ion
0a49715c3d Merge pull request #183 from iszcz/patch-1
清除mj prompt里的--mode
2024-04-09 00:46:47 +08:00
Calcium-Ion
89efed48fc Merge pull request #185 from h1xy/main
Fix: CompletionRatio is not working for openrouter.ai
2024-04-08 23:57:37 +08:00
Calcium-Ion
97e0aae0a7 Merge pull request #188 from Calcium-Ion/fix/many-model-error
fix: 修复渠道一次性添加很多model失败
2024-04-08 23:56:45 +08:00
Xyfacai
320da09f36 fix: 修复渠道一次性添加很多model失败
修复渠道一次性添加很多model并且group多
提示失败 too many SQL variables
2024-04-08 23:51:51 +08:00
CaIon
2d849e0dd6 fix: 307本地重试 2024-04-08 14:10:09 +08:00
CaIon
60d7ed3fb5 fix: distributor panic 2024-04-08 13:48:36 +08:00
h1xy
c5f6d0e063 Fix: CompletionRatio is not working for openrouter.ai
https://openrouter.ai/docs#models
Model name of openrouter is prefix with company name, e.g. "model": "anthropic/claude-3-opus:beta", therefore, CompletionRatio will not working for it which is only work for prefix with claude-xxx
2024-04-08 02:12:47 +08:00
CaIon
a7cfce24d0 feat: automatically ban channels that exceeded quota 2024-04-07 22:22:27 +08:00
CaIon
34bf8f8945 fix: select channel 2024-04-07 22:08:11 +08:00
CaIon
2d1d1b4631 update go-epay 2024-04-07 14:42:03 +08:00
iszcz
5961de03e7 清除--mode 2024-04-06 23:08:50 +08:00
CaIon
fbdb17022c update README.md 2024-04-06 20:49:34 +08:00
CaIon
497cc32634 update README.md 2024-04-06 20:47:16 +08:00
CaIon
462c328d4b feat: 支持未开启缓存下本地重试 2024-04-06 20:45:18 +08:00
CaIon
257cfc2390 fix: email whitelist check 2024-04-06 17:50:47 +08:00
CaIon
fed1a1d6a3 feat: 超时状态码不重试 2024-04-04 21:21:44 +08:00
CaIon
fc9f8c8e8a fix: add group tag 'unknown' 2024-04-04 21:20:54 +08:00
CaIon
f3f36dafbd chore: 优化按次计费的数据库查询次数 2024-04-04 20:10:30 +08:00
CaIon
aaf3a1f07b fix: GetRandomSatisfiedChannel 2024-04-04 19:37:33 +08:00
CaIon
c040fa229d fix bug 2024-04-04 19:18:00 +08:00
CaIon
1cd1e54be4 feat: 钱包兼容非货币形式显示额度 2024-04-04 18:21:23 +08:00
CaIon
3db64afc7f feat: 钱包兼容非货币形式显示额度 2024-04-04 18:20:38 +08:00
CaIon
bc9cfa5da0 feat: 钱包兼容非货币形式显示额度 2024-04-04 18:18:18 +08:00
CaIon
660b9b3c99 feat: able to set default test model (#138) 2024-04-04 17:29:25 +08:00
CaIon
cdf2087952 update README.md 2024-04-04 16:48:28 +08:00
CaIon
4b60528c5f feat: 本地重试 2024-04-04 16:35:44 +08:00
1808837298@qq.com
9025756b56 fix: email whitelist check 2024-04-04 12:33:11 +08:00
CaIon
2ea6009954 fix: user update error 2024-04-04 11:10:41 +08:00
CaIon
a33f685f3c fix: log page type error (close #154) 2024-04-03 23:57:49 +08:00
CaIon
3d0f77ffb6 Merge remote-tracking branch 'origin/main' 2024-04-03 23:51:32 +08:00
CaIon
5ce8e6dab6 fix: update user quote (close #161) 2024-04-03 23:51:25 +08:00
Calcium-Ion
5a5b7d618d Merge pull request #171 from QuentinHsu/perf-setting-tab-navigation
perf(Setting): setting tab navigation
2024-04-03 23:32:19 +08:00
Calcium-Ion
ad8ce915ec Merge pull request #175 from ye4293/test
修改了用户注册使用临时邮箱验证的问题
2024-04-03 23:31:50 +08:00
Calcium-Ion
456fb875de Merge pull request #176 from QuentinHsu/perf-helpers-renderGroup
refactor(helpers): renderGroup function
2024-04-03 23:31:02 +08:00
QuentinHsu
3e90b6d516 refactor(helpers): renderGroup function 2024-04-02 13:16:02 +08:00
QuentinHsu
d6e373fbe4 fix(helpers): add key prop to Tag components 2024-04-02 10:58:44 +08:00
Ghostz
224746b45a Update misc.go 2024-04-02 01:13:12 +08:00
Calcium-Ion
ac827b1862 Merge pull request #174 from AI-ASS/main 2024-04-01 19:51:02 +08:00
GAI Group
658bf2ad57 Rename .prettierrc.mjs to .prettierrc.mjs 2024-04-01 19:49:56 +08:00
Calcium-Ion
c25f48b7c5 Merge pull request #172 from MapleEve/main
Support Claude TopK
2024-04-01 18:15:45 +08:00
QuentinHsu
290dcf7587 perf(Setting): add useEffect and useNavigate hooks to Setting component 2024-04-01 16:59:07 +08:00
Maple Gao
278fd39195 feat: add Claude TopK 2024-04-01 14:33:58 +08:00
QuentinHsu
aa23c51a53 perf(Setting): add tabActiveKey state to Setting component 2024-04-01 13:33:57 +08:00
Calcium-Ion
87919b032d Merge pull request #167 from weikecloud/main
增加MJ上游构图失败判断
2024-03-30 16:27:03 +08:00
Calcium-Ion
f7a4f18aff Update midjourney.go 2024-03-30 16:26:39 +08:00
余生一个白恩
706449dede 增加上游构图失败判断 2024-03-30 13:21:05 +08:00
CaIon
36d164be0e fix: SearchUsers (close #160) 2024-03-29 22:49:08 +08:00
CaIon
d80a7d3c97 Merge remote-tracking branch 'origin/main' 2024-03-29 22:28:10 +08:00
CaIon
44a8ade4ba fix: remove sensitive check on completion (close #157) 2024-03-29 22:20:14 +08:00
Xyfacai
2cca2a989e Merge pull request #165 from xyfacai/fork/mj-mode-path
fix: 支持 /mj-{mode} 路径
2024-03-29 17:45:23 +08:00
Xiangyuan Liu
3065bf92ae fix: 支持 /mj-{mode} 路径 2024-03-29 17:45:00 +08:00
Xiangyuan Liu
2e595bdafb fix: 支持 /mj-{mode} 路径 2024-03-29 16:58:19 +08:00
Xiangyuan Liu
49df4b6eed feat: 支持 /mj-{mode} 路径 2024-03-29 16:48:50 +08:00
CaIon
5c39f54040 feat: able to set smtp ssl 2024-03-28 12:18:11 +08:00
CaIon
786ccc7da0 feat: 开启redis的情况下设置SYNC_FREQUENCY默认为60 2024-03-26 23:00:04 +08:00
CaIon
8eedad9470 feat: support ollama embedding 2024-03-26 19:53:53 +08:00
CaIon
319e97d677 fix: ollama channel test 2024-03-26 19:27:11 +08:00
CaIon
6114c9bb96 fix: CountTokenInput 2024-03-26 18:49:53 +08:00
CaIon
3cf2f0d5cb fix: CountTokenInput 2024-03-26 18:21:38 +08:00
CaIon
2a345ae070 ci: update ci 2024-03-25 22:55:33 +08:00
CaIon
d8c91fa448 feat: 进一步防止暴露数上游以及数据库地址 2024-03-25 22:54:15 +08:00
CaIon
cc8cc8b386 fix: try to fix 307 2024-03-25 22:51:31 +08:00
CaIon
1587ea565b feat: support gemini-1.5 2024-03-25 22:33:46 +08:00
CaIon
a7a1fc615d feat: remove azure model TrimPrefix 2024-03-25 22:33:33 +08:00
CaIon
b2a280c1ec fix: 无法复制弹窗过小 2024-03-25 16:49:53 +08:00
CaIon
f1fb7b32a3 chore: update model ratio 2024-03-25 16:17:35 +08:00
CaIon
3800dc219e fix: Cannot read properties of undefined (reading 'map') (close #148) 2024-03-25 14:11:28 +08:00
CaIon
72962e988f Merge remote-tracking branch 'origin/main' 2024-03-25 13:52:37 +08:00
Calcium-Ion
01e3acfada Merge pull request #145 from QuentinHsu/fix-dev-error
fix(global): error in console under dev mode
2024-03-25 13:52:22 +08:00
QuentinHsu
f671176da0 fix(global): error in console under dev mode 2024-03-24 18:50:21 +08:00
CaIon
2d36dee17c fix: 流模式网络错误导致0补 2024-03-23 23:52:04 +08:00
CaIon
6eb30ec3e6 fix: 模型倍率和价格无法设置 2024-03-23 23:24:17 +08:00
CaIon
0b3520e3c8 fix: 修复默认模型倍率未显示 2024-03-23 23:21:14 +08:00
CaIon
63304a5b2d Merge remote-tracking branch 'origin/main' 2024-03-23 21:49:13 +08:00
CaIon
66e30f4115 fix: ci 2024-03-23 21:47:51 +08:00
Calcium-Ion
0618f03c68 Merge pull request #141 from Calcium-Ion/vite-support
feat: vite
2024-03-23 21:42:56 +08:00
CaIon
962dc984f4 chore: lint fix 2024-03-23 21:24:39 +08:00
CaIon
15e7307320 feat: prettier 2024-03-23 21:23:39 +08:00
CaIon
951383c371 chore: delete useless file 2024-03-23 21:02:38 +08:00
CaIon
87b6210045 chore: delete useless dir 2024-03-23 21:00:23 +08:00
CaIon
525fc1b3b7 feat: 从本地读取字体 (close #130) 2024-03-23 20:57:52 +08:00
CaIon
58f2cf3a79 feat: 首页加载速度优化 2024-03-23 20:22:00 +08:00
CaIon
06c86397e1 chore: Chunking Strategy 2024-03-23 19:37:19 +08:00
CaIon
21f48b55e0 fix: embed 2024-03-23 19:27:18 +08:00
CaIon
f823b4d4d8 update Dockerfile 2024-03-23 19:18:28 +08:00
CaIon
93be61aaf3 feat: vite 2024-03-23 19:09:09 +08:00
Calcium-Ion
a500097b36 Merge pull request #137 from MapleEve/main
feat: support 01.AI
2024-03-23 17:38:53 +08:00
CaIon
67332bc8df fix: 模型固定价格为空时错误使用默认价格 2024-03-23 17:19:29 +08:00
CaIon
d0acecb2ab fix: GLM-4V 的 Vision 兼容问题 (close #136) 2024-03-23 17:08:34 +08:00
Maple Gao
a825699e9a Merge branch 'Calcium-Ion:main' into main 2024-03-23 01:10:39 +08:00
CaIon
a70ca53449 fix: mj 2024-03-22 21:39:44 +08:00
CaIon
c33b1522cc fix: 充值并发导致订单号相同 2024-03-21 23:57:48 +08:00
CaIon
ff7da08bad fix: add missing created 2024-03-21 23:46:43 +08:00
CaIon
3e03c5a742 fix: add missing id,object,created 2024-03-21 23:44:39 +08:00
CaIon
d9344d79cf fix: try to fix curl: (18) 2024-03-21 23:25:07 +08:00
CaIon
c4b3d3a975 fix: fix embedding 2024-03-21 17:39:05 +08:00
CaIon
031957714a refactor: 代码结构优化 2024-03-21 17:19:21 +08:00
CaIon
3f808be254 fix: add missing version 2024-03-21 16:26:26 +08:00
CaIon
9b64f4a34a fix: fix mj panic 2024-03-21 15:04:04 +08:00
CaIon
222a55387d fix: fix SensitiveWords error 2024-03-21 14:29:56 +08:00
Maple Gao
492001a8b2 Merge branch 'Calcium-Ion:main' into main 2024-03-21 01:40:29 +08:00
CaIon
d7e25e1604 fix: fix SensitiveWords load error 2024-03-20 23:58:42 +08:00
Maple Gao
7d64f30f4d Add: 01AI in readme 2024-03-20 23:51:45 +08:00
Maple Gao
9e157ed802 fix empty url 2024-03-20 23:49:16 +08:00
Maple Gao
cfabf8a656 Add 01.AI relay 2024-03-20 23:44:03 +08:00
CaIon
c7658b70d1 fix: 敏感词库为空时,不检测 2024-03-20 22:33:22 +08:00
CaIon
d5e93e788d fix: midjourneys table 2024-03-20 21:49:54 +08:00
CaIon
dd71946047 fix: claude panic 2024-03-20 21:32:33 +08:00
CaIon
b736de7189 fix: claude panic 2024-03-20 21:28:45 +08:00
Calcium-Ion
5e7774cf08 Merge pull request #131 from Calcium-Ion/sensitive-words
feat: 敏感词过滤
2024-03-20 20:38:15 +08:00
CaIon
a232afe9fd feat: 统一错误提示 2024-03-20 20:36:55 +08:00
CaIon
eb6257a8d8 fix: fix SensitiveWordContains not working 2024-03-20 20:28:00 +08:00
CaIon
47b9f48544 fix: fix error 2024-03-20 20:26:34 +08:00
CaIon
2db4282666 feat: 保留功能 2024-03-20 20:15:32 +08:00
CaIon
64b9d3b58c feat: 初步兼容生成内容检查 2024-03-20 19:00:51 +08:00
CaIon
7a663d26ec feat: 初步兼容敏感词过滤 2024-03-20 17:07:42 +08:00
CaIon
bec21ade9d Update README.md 2024-03-18 20:44:20 +08:00
CaIon
4c4e087060 fix: api type error 2024-03-18 19:55:19 +08:00
CaIon
beadb98a8c feat: support perplexity 2024-03-18 18:09:41 +08:00
CaIon
615d109d70 Merge remote-tracking branch 'origin/main' 2024-03-18 18:08:09 +08:00
Calcium-Ion
0da49fa446 Merge pull request #125 from wozulong/main
fix: openai_organization not working
2024-03-18 18:07:34 +08:00
CaIon
578b5f6536 feat: support perplexity 2024-03-18 18:00:19 +08:00
我秦始皇
f63ad9c03c fix: make the 'openai_organization' parameter actually work. 2024-03-18 02:19:59 -07:00
我秦始皇
0fb98e44a7 fix: make the 'openai_organization' parameter actually work. 2024-03-18 02:11:21 -07:00
CaIon
652bb4a53c Update README.md 2024-03-18 00:16:32 +08:00
CaIon
a26b9a9bff Update Midjourney.md 2024-03-18 00:13:37 +08:00
CaIon
f82aa956bd feat: midjourneys表添加索引 2024-03-17 22:56:17 +08:00
CaIon
b8c053c37f fix: chatglm top_p error (close #124) 2024-03-17 22:55:29 +08:00
CaIon
2581b37394 fix: 侧边栏聚焦问题 (close #123) 2024-03-17 16:37:19 +08:00
CaIon
f65477d054 fix: 修复日志翻页问题 (close #122) 2024-03-17 16:18:45 +08:00
CaIon
4b952b8582 feat: print midjourney-proxy request error 2024-03-16 17:25:10 +08:00
CaIon
a6ba1d01d9 feat: 添加回调未开启提示 2024-03-15 22:15:16 +08:00
CaIon
2d2fec24d0 chore: update dependency 2024-03-15 18:35:15 +08:00
CaIon
d34b55c154 chore: reformat code 2024-03-15 16:05:33 +08:00
CaIon
25ec99913b feat: Add logs page-size 2024-03-15 15:07:14 +08:00
CaIon
7f22d58574 feat: Only update weight and priority when blur 2024-03-15 14:53:33 +08:00
CaIon
dfdeadf1a5 feat: Remember and automatically set page-size (close #118) 2024-03-15 14:45:52 +08:00
CaIon
6ab1b3a524 fix: fix image-seed error 2024-03-14 22:17:03 +08:00
CaIon
4917e5a92f feat: 允许开关mj回调 2024-03-14 21:21:37 +08:00
Calcium-Ion
f62dcbf669 Merge pull request #114 from Calcium-Ion/midjourney-proxy-plus
feat: support midjourney-proxy-plus
2024-03-14 18:49:30 +08:00
CaIon
299911d4cd fix: 修复epay可能重复回调的问题 2024-03-14 18:33:35 +08:00
CaIon
84e0544604 refactor: 修改超时时间 2024-03-14 18:19:22 +08:00
CaIon
2786a6b539 Update README.md 2024-03-14 18:10:09 +08:00
CaIon
9b5353a81a feat: support InsightFace (close #60) 2024-03-14 18:09:57 +08:00
CaIon
bc5a54df59 feat: support image-seed (close #86) 2024-03-14 18:09:56 +08:00
CaIon
d704902b70 feat: 兼容自定义变焦,完善modal操作 2024-03-14 16:42:37 +08:00
CaIon
614220a0fb feat: 超过一小时的任务自动失败 2024-03-14 15:16:36 +08:00
CaIon
98c1f66d61 feat: support Anthropic Claude 3.0 Haiku (close #116) 2024-03-14 14:28:21 +08:00
CaIon
a77fbc0fa2 fix: reroll action error 2024-03-14 00:43:32 +08:00
CaIon
44361d75e8 fix: "Inpaint" code error 2024-03-13 23:17:12 +08:00
CaIon
9b2e5c2978 refactor: remove consumeQuota 2024-03-13 22:30:10 +08:00
CaIon
d3399d68f6 fix: fix typo 2024-03-13 22:24:02 +08:00
CaIon
3d10c9f090 feat: 将操作拆分成单独的模型 2024-03-13 21:19:48 +08:00
CaIon
d5ffaf2502 feat: 操作细分 2024-03-13 18:26:16 +08:00
CaIon
2ad591411e feat: support shorten 2024-03-13 17:46:34 +08:00
CaIon
728dbed28d feat: 兼容变焦功能 2024-03-13 16:29:27 +08:00
CaIon
fd3a41bacb feat: 请求超时处理 2024-03-13 16:19:22 +08:00
CaIon
37c0c8ebdd feat: 初步兼容midjourney-proxy-plus 2024-03-13 15:37:01 +08:00
CaIon
95d8059c90 Merge remote-tracking branch 'origin/main' 2024-03-12 18:18:57 +08:00
CaIon
c012306400 feat: add parameter top_k 2024-03-12 18:18:47 +08:00
Calcium-Ion
1e1a53e4d2 Merge pull request #113 from xyspg/main
fix: layout issues
2024-03-12 17:36:02 +08:00
Kang Jiaming
30d48ea473 fix: layout issues 2024-03-12 12:49:36 +08:00
CaIon
9e5c636490 Update README.md 2024-03-12 02:39:10 +08:00
CaIon
d53d3386e9 feat: support ollama (close #112) 2024-03-12 02:36:39 +08:00
CaIon
c3a01decd8 chore: drop idx_channels_key on start 2024-03-12 00:35:12 +08:00
CaIon
c29680b301 feat: 前端美化 2024-03-11 21:08:51 +08:00
CaIon
ac7407ce9c Update README.md 2024-03-11 19:27:54 +08:00
Calcium-Ion
3ea061a820 Merge pull request #98 from h1xy/main
Update README.md
2024-03-11 19:23:42 +08:00
CaIon
d4df1960b2 fix: 修复模型映射功能失效 (close #105) 2024-03-11 18:28:51 +08:00
CaIon
97dd80541b fix: 请求阿里通义千问异常 (close #108) 2024-03-11 14:35:46 +08:00
CaIon
bf241b218f fix: 删除用户改为软删除 (close #107) 2024-03-11 14:13:45 +08:00
CaIon
7ab6c6c303 fix: 修复claude渠道流模式计费可能异常 2024-03-09 13:25:47 +08:00
CaIon
8fe8340b6e fix: fix gemini channel test 2024-03-08 21:38:51 +08:00
CaIon
26ef906c61 fix: fix claude channel test 2024-03-08 21:38:43 +08:00
CaIon
655dfe0d09 feat: update claude default model ratio 2024-03-08 21:17:32 +08:00
CaIon
f43b268520 fix: fix claude 3 request missing the 'max_token' field 2024-03-08 21:16:12 +08:00
CaIon
37113c0e96 perf: 优化其他设置显示效果 2024-03-08 20:35:15 +08:00
CaIon
3c3c53051d fix: fix gemini 2024-03-08 20:29:04 +08:00
CaIon
6a83c8ad86 fix: 隐藏无用按钮 2024-03-08 20:14:49 +08:00
CaIon
0640dd81fd fix: fix Azure channel test (close #101) 2024-03-08 20:06:40 +08:00
CaIon
cb50fcaffe Update README.md 2024-03-08 19:51:52 +08:00
Calcium-Ion
eca48268b2 Merge pull request #103 from Calcium-Ion/dev
feat: support Claude 3
2024-03-08 19:47:24 +08:00
CaIon
4a0af1ea3c feat: support Claude 3 2024-03-08 19:43:33 +08:00
CaIon
c2965eb835 feat: support claude3 not stream 2024-03-08 18:26:18 +08:00
Calcium-Ion
4186880e4c Merge pull request #96 from QuentinHsu/refactor-other-setting
refactor(OtherSetting): change UI, improve interaction
2024-03-08 16:52:45 +08:00
1808837298@qq.com
280c63e1d4 feat: 添加充值记录按钮 2024-03-07 22:41:04 +08:00
h1xy
9e0a54e943 Update README.md
增加telegram授权登录的配置步骤。
2024-03-06 23:37:32 +08:00
1808837298@qq.com
0e06be8c3e fix: fix bug 2024-03-06 19:53:31 +08:00
QuentinHsu
931d22c96f perf(OtherSetting): code logic 2024-03-06 18:26:22 +08:00
QuentinHsu
6413bf342a refactor(OtherSetting): change UI, improve interaction 2024-03-06 17:57:53 +08:00
1808837298@qq.com
626217fbd4 fix: 修复流模式错误扣费的问题 (close #95) 2024-03-06 17:41:55 +08:00
1808837298@qq.com
9de2d21e1a fix: 恢复微信公众号扫码功能 (close #7) 2024-03-06 16:53:21 +08:00
1808837298@qq.com
3ab4f145db feat: 支持部分渠道的system角色 (close #89) 2024-03-06 14:16:04 +08:00
Calcium-Ion
da73dca9a7 Merge pull request #81 from QuentinHsu/fix-display-home-content
fix: the content displayed on the homepage is incorrect
2024-03-05 23:16:48 +08:00
Calcium-Ion
415d296171 Merge branch 'main' into fix-display-home-content 2024-03-05 23:16:35 +08:00
Calcium-Ion
d5c5c30312 Merge pull request #91 from QuentinHsu/fix-model-name-not-trimmed
fix: model name not trimmed
2024-03-05 23:08:19 +08:00
1808837298@qq.com
fb39b6b30e fix: Fix the issue of 'unknown relay mode' when making an embedding request (close #93) 2024-03-05 23:04:57 +08:00
QuentinHsu
92c1ed7f1d perf: prompt when the name of the custom model input already exists 2024-03-05 12:18:30 +08:00
QuentinHsu
d160736a49 fix: model names must not contain spaces at both ends 2024-03-05 12:16:13 +08:00
1808837298@qq.com
fe7f42fc2e feat: add "/api/status/test" 2024-03-04 19:32:59 +08:00
QuentinHsu
5cb933a278 fix: restore the set of StatusContext 2024-03-04 16:39:42 +08:00
1808837298@qq.com
912f46fcd2 feat: 记录请求的token用量 2024-03-04 14:58:30 +08:00
1808837298@qq.com
1ad1112351 feat: 解决渠道密钥长度限制问题 2024-03-04 14:55:26 +08:00
1808837298@qq.com
1f0a48a879 fix: 修复保存telegram设置后会将telegram登陆禁用的问题 2024-03-03 23:41:00 +08:00
1808837298@qq.com
d2de675564 chore: update README.md 2024-03-03 22:28:07 +08:00
Calcium-Ion
a5cfeeaa63 Merge pull request #79 from Ehco1996/telegram-login
feat: support telegram login
2024-03-03 22:08:03 +08:00
1808837298@qq.com
37b307a784 fix: 修复预扣费判定无效导致用户可无限欠费问题 2024-03-03 22:07:20 +08:00
Ehco
02d5a5f16d Merge branch 'main' into telegram-login 2024-03-03 19:42:06 +08:00
Ehco1996
699fe256d0 address comment 2024-03-03 19:41:15 +08:00
1808837298@qq.com
54088bc664 fix: 修复添加和编辑渠道无可选择模型 2024-03-03 15:53:48 +08:00
1808837298@qq.com
06b746a740 fix: remove useless code 2024-03-02 23:02:49 +08:00
1808837298@qq.com
4c43012e6c fix: Improve handling of small weights in channel selection logic 2024-03-02 22:46:26 +08:00
1808837298@qq.com
a16e949318 feat: update SparkDesk model ratio 2024-03-02 22:08:59 +08:00
1808837298@qq.com
72194bda98 feat: 可自定义支付回调地址及最低购买数量 2024-03-02 22:07:32 +08:00
1808837298@qq.com
ca5c4d2dbd Merge branch 'fix-model-name' 2024-03-02 21:16:05 +08:00
1808837298@qq.com
cbe0a1fa76 fix: add missing UpstreamModelName 2024-03-02 21:05:02 +08:00
1808837298@qq.com
d6d8ad2b7e feat: 数据看板颜色统一 2024-03-02 20:56:29 +08:00
papersnake
7297c6269f fix: add missing upstreamModelName 2024-03-02 12:26:39 +00:00
Ehco1996
cb92d6fd5f revert compose 2024-03-02 19:58:36 +08:00
Ehco
e5cea80103 Merge pull request #1 from Ehco1996/telegram-login-complete
feat: telegram login and bind
2024-03-02 17:24:34 +08:00
sljeff
84144306a8 feat: telegram login and bind 2024-03-02 17:15:52 +08:00
QuentinHsu
ac64fd26ad fix: the content displayed on the homepage is incorrect 2024-03-02 15:23:38 +08:00
1808837298@qq.com
a8f0c5dab2 feat: 可设置默认折叠侧边栏 2024-03-02 02:12:02 +08:00
1808837298@qq.com
2b076eaed2 fix: fix typo 2024-03-01 22:39:42 +08:00
Ehco1996
194ff1ac57 Update makefile and user model, add Telegram integration 2024-03-01 22:33:30 +08:00
1808837298@qq.com
84cac72a45 feat: 修复智谱GLM-4V流模式异常 2024-03-01 22:31:08 +08:00
1808837298@qq.com
413d4f0a66 feat: Add channel search by model field (close #72) 2024-03-01 21:57:52 +08:00
Ehco1996
690d2e6ba2 Add Telegram OAuth support 2024-03-01 21:54:24 +08:00
Ehco1996
140e843d93 Add Telegram bot token and name 2024-03-01 20:39:28 +08:00
1808837298@qq.com
feb40db2bc feat: "/v1/models" 只返回用户可用模型 (close #78) 2024-02-29 19:21:22 +08:00
1808837298@qq.com
b4645d1019 feat: 支持智谱GLM-4V 2024-02-29 18:31:03 +08:00
1808837298@qq.com
a3687b72f8 fix: fix preConsumeQuota error 2024-02-29 16:39:52 +08:00
1808837298@qq.com
6013219f5b feat: 初步重构完成 2024-02-29 16:21:25 +08:00
1808837298@qq.com
5b18cd6b0a feat: 初步重构 2024-02-29 01:08:18 +08:00
1808837298@qq.com
9b421478c1 fix: fix typo 2024-02-24 14:01:29 +08:00
1808837298@qq.com
569751f5db fix: 修复测试渠道指定模型时实际使用gpt-3.5-turbo的问题 (close #67) 2024-02-23 17:40:49 +08:00
1808837298@qq.com
608c945ae6 feat: 未配置在线支付时,限制用户输入金额 (close #74) 2024-02-23 17:28:00 +08:00
1808837298@qq.com
05beade3ad feat: 在线支付限制输入金额 2024-02-23 17:18:50 +08:00
1808837298@qq.com
81167c4322 feat: 记录更多的错误信息 2024-02-22 19:42:33 +08:00
1808837298@qq.com
6fa1837c24 feat: support logprobs 2024-02-22 17:52:03 +08:00
1808837298@qq.com
ba0c5bb4d9 feat: 优化令牌无效提示 2024-02-22 01:48:35 +08:00
1808837298@qq.com
5482fff62d fix: 修复redis缓存用户额度不过期的问题 2024-02-20 00:23:31 +08:00
CaIon
1cde58c089 fix: 修正 Unicode 转义序列解析问题,close #63 2024-02-02 20:02:40 +08:00
CaIon
a0c4e9f10e feat: add model gpt-3.5-turbo-0125 2024-02-02 12:46:12 +08:00
CaIon
d6d91e4340 feat: Happy New Year 2024-02-01 21:47:28 +08:00
CaIon
f19b11be8c fix: fix test channel 2024-02-01 19:27:00 +08:00
CaIon
1bfc46aa70 ci: update workflows 2024-02-01 19:02:57 +08:00
CaIon
5c8f8b4901 feat: 添加成功时自动启用通道功能, close #27 2024-02-01 18:52:39 +08:00
CaIon
affe5111cc feat: 支持测试渠道时自选模型 2024-02-01 18:31:42 +08:00
CaIon
0004c1022d fix: support AI Gateway, close #52 2024-02-01 18:11:00 +08:00
CaIon
4274b925da fix: support AI Gateway, close #52 2024-02-01 17:48:55 +08:00
CaIon
166770eebf fix: 修复是否自动禁用选项不生效 2024-02-01 17:22:26 +08:00
CaIon
6d0479632a fix: fix tool calls 2024-01-31 01:41:38 +08:00
CaIon
364d4f96c7 feat: add model text-embedding-3 2024-01-26 16:53:04 +08:00
CaIon
bce8a3e2be feat: add model gpt-4-turbo-preview 2024-01-26 16:52:16 +08:00
CaIon
4920922929 fix: fix redis error 2024-01-26 16:09:50 +08:00
CaIon
70bdb8ca90 fix: fix redis error 2024-01-26 15:52:23 +08:00
CaIon
3dc5d4d95a feat: update model-ratio 2024-01-26 14:21:10 +08:00
CaIon
ccaf64bbfc chore: check before get userGroup 2024-01-25 20:10:32 +08:00
CaIon
ac3e27859c feat: token缓存逻辑更新(实验性) 2024-01-25 20:09:06 +08:00
CaIon
e8188902c2 feat: 数据看板加入每列总计 2024-01-25 14:57:46 +08:00
CaIon
1ee8edcfd4 feat: 优化vision计费逻辑 2024-01-25 14:57:13 +08:00
CaIon
a3921ea54d feat: add midjourney price log 2024-01-25 14:56:49 +08:00
CaIon
705171e495 chore: change use_time to int 2024-01-22 17:07:18 +08:00
CaIon
b3f46223a8 feat: 请求出现 0 token的时候,加入错误提示并打印日志 2024-01-21 20:44:26 +08:00
Calcium-Ion
5b2377eea9 Merge pull request #53 from Calcium-Ion/fork/main
feat: 令牌聊天新增ChatGPT Web & Midjourney支持
2024-01-21 17:47:55 +08:00
CaIon
c59a33e8e9 feat: 令牌聊天新增ChatGPT Web & Midjourney支持 2024-01-21 17:43:40 +08:00
CaIon
e485bc7613 chore: remove unused import 2024-01-21 17:02:35 +08:00
CaIon
9bcd24fc2c feat: 美化日志详情 2024-01-21 16:20:24 +08:00
CaIon
cbdce181af feat: 日志优化逻辑,新增请求时间和是否为stream字段 2024-01-21 15:01:59 +08:00
CaIon
800f494698 Merge remote-tracking branch 'origin/main' 2024-01-20 00:13:36 +08:00
CaIon
7601610767 update package.json 2024-01-20 00:13:18 +08:00
Calcium-Ion
3a8be505c1 Merge pull request #51 from wxDadadada/main
修改数据看板文字
2024-01-17 14:29:01 +08:00
wxDadadada
51f7ad5de2 Update SiderBar.js 2024-01-17 12:48:55 +08:00
GuoRuqiang
f73a180fc3 Changes to be committed:
modified:   web/src/components/TokensTable.js
2024-01-16 18:50:01 +00:00
GuoRuqiang
e8db0a2c72 增加了一个超链聊天跳转
在“运营设置里面”增加了“聊天页面2链接”,方便将项目(https://github.com/Dooy/chatgpt-web-midjourney-proxy) 替换掉原来的AMA问天。

Changes to be committed:
    modified:   common/constants.go
    modified:   controller/misc.go
    modified:   model/option.go
    modified:   web/src/App.js
    modified:   web/src/components/OperationSetting.js
    modified:   web/src/components/TokensTable.js
2024-01-16 18:15:55 +00:00
Xyfacai
d3e070d963 fix: mj 错误处理 2024-01-14 16:17:45 +08:00
CaIon
e688e41360 perf: 美化绘画界面UI 2024-01-13 22:38:50 +08:00
CaIon
e41fcd563e chore: 数据看板限制用户只能查询跨度一个月的数据 2024-01-13 01:32:23 +08:00
CaIon
a8715c61c8 fix: 修复数据看板渲染错误 2024-01-13 01:27:50 +08:00
CaIon
00306aa142 perf: 数据看板支持选择时间粒度 2024-01-13 00:33:52 +08:00
CaIon
d30b9321b2 chore: UpdateMidjourneyTaskBulk with SafeGoroutine 2024-01-12 18:38:45 +08:00
CaIon
2b3539fcaa fix: 修复渠道一致性问题 2024-01-12 14:57:03 +08:00
CaIon
e2a1caba4c fix: 修复渠道一致性问题 2024-01-12 14:36:15 +08:00
CaIon
febcadb42c chore: add SafeGoroutine 2024-01-12 13:49:53 +08:00
CaIon
075b1ac113 fix: delete RelayPanicRecover 2024-01-12 13:48:13 +08:00
CaIon
6757166de5 Merge remote-tracking branch 'origin/main' 2024-01-12 13:46:08 +08:00
CaIon
2a9c3ac6af fix: 修复mj错误返还费用问题 2024-01-12 13:45:52 +08:00
Calcium-Ion
2ccd6c04e0 Merge pull request #45 from AI-ASS/patch-1
fix: the Redis problem in the CacheGetUsername function
2024-01-12 13:17:20 +08:00
GAI Group
6aadcf0c78 fix: the Redis problem in the CacheGetUsername function
fix: the Redis problem in the CacheGetUsername function
2024-01-12 13:15:43 +08:00
CaIon
312417f393 update README.md 2024-01-12 00:35:22 +08:00
CaIon
ad85792818 fix: user group size 2024-01-11 22:06:30 +08:00
CaIon
521ef5e219 fix: token model limit 2024-01-11 19:18:48 +08:00
CaIon
f07b9f8ab2 chore: cache username 2024-01-11 18:40:44 +08:00
CaIon
64e9e9cc20 fix: 完善令牌预扣费逻辑 2024-01-11 14:12:48 +08:00
CaIon
701a28d0da fix: fix relay openai panic 2024-01-10 19:26:11 +08:00
CaIon
d04d2a6c4d perf: UI美化 2024-01-10 18:32:44 +08:00
CaIon
e2317524f9 perf: 美化数据看板 2024-01-10 17:49:55 +08:00
CaIon
a3b726dd82 fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费的问题 2024-01-10 14:23:23 +08:00
CaIon
042d55cfd3 fix: fix response choice json 2024-01-10 13:57:49 +08:00
CaIon
3432d9e0f6 fix: do not consume user quota if failed 2024-01-10 13:56:29 +08:00
CaIon
aca8d25372 feat: able to change gemini safety setting 2024-01-10 13:56:10 +08:00
CaIon
6a24e8953f feat: able to fix channels 2024-01-10 13:23:43 +08:00
CaIon
3e13810ca2 fix: fix channel panic 2024-01-10 00:17:03 +08:00
CaIon
4a4df75830 docs: update Midjourney.md 2024-01-09 20:07:38 +08:00
CaIon
db29c96361 fix: add Tools support 2024-01-09 16:44:34 +08:00
CaIon
1618a8c9fc perf: 优化数据看板性能 2024-01-09 16:20:04 +08:00
CaIon
75b6327f4f feat: support Azure dall-e 2024-01-09 15:46:45 +08:00
CaIon
fb95216b5a chore: delete model price log 2024-01-09 12:11:09 +08:00
CaIon
6ed365c267 fix: 修复数据看板筛选用户的时候不能指定时间的问题 2024-01-09 12:11:09 +08:00
CaIon
4f95b7d049 update README.md 2024-01-09 12:11:08 +08:00
Calcium-Ion
954ed893b9 Update LICENSE 2024-01-08 19:48:30 +08:00
CaIon
95e84f2bb1 perf: 优化数据看板性能和显示效果 2024-01-08 18:49:10 +08:00
CaIon
f3f8cdc4a3 update README.md 2024-01-08 16:26:00 +08:00
CaIon
1244963e81 feat: 可设置令牌能调用的模型 2024-01-08 16:25:17 +08:00
CaIon
8f36a995ef perf: 美化数据看板 2024-01-08 11:32:27 +08:00
CaIon
daf0be4915 fix: fix JSON tag in Log struct 2024-01-08 11:25:22 +08:00
CaIon
33b93e1e9a 删除前端无用log 2024-01-08 11:24:47 +08:00
CaIon
f3124e7252 fix: 减少username和model_name字段长度为64 2024-01-08 10:11:36 +08:00
CaIon
ebedf2cb7e fix: 减少group和model字段长度为64 2024-01-08 10:08:20 +08:00
Calcium-Ion
d68a50125b Update README.md 2024-01-08 00:25:12 +08:00
CaIon
18e9d37057 fix: 修复侧边导航栏需要刷新才出现选项的问题 2024-01-07 23:54:51 +08:00
CaIon
e2d994d73a fix: 修复mj固定价格设置无效的问题 2024-01-07 23:40:06 +08:00
CaIon
29dbdf01f0 fix: 数据面板次数统计错误 2024-01-07 23:28:27 +08:00
CaIon
64862cd634 fix: recover panic 2024-01-07 22:25:03 +08:00
CaIon
a5e0f86c65 fix: 索引长度 2024-01-07 22:16:20 +08:00
CaIon
2a995a5da2 fix: 索引名称长度 2024-01-07 22:07:17 +08:00
CaIon
bba6174745 fix: fix gemini panic 2024-01-07 21:27:28 +08:00
CaIon
b91a269ddb fix: 修复数据看板饼图显示错误 2024-01-07 20:48:09 +08:00
CaIon
1c2bba8979 feat: 完善数据看板选择时间区间 2024-01-07 19:47:35 +08:00
CaIon
ce05e7dd86 fix: 优化数据看板更新逻辑 2024-01-07 18:33:41 +08:00
CaIon
bf8794d257 feat: 新增数据看板 2024-01-07 18:31:14 +08:00
CaIon
c09df83f34 feat: 允许关闭绘图选项 2024-01-07 18:31:04 +08:00
CaIon
79b1201f47 fix: 修复兑换码复制带sk- 2024-01-06 21:49:11 +08:00
CaIon
b3fc335908 修复mj计费问题 2024-01-05 12:23:26 +08:00
Calcium-Ion
290fed3841 Merge pull request #41 from Calcium-Ion/optimize/mj
optimize: MJ 部分调整、优化
2024-01-03 22:45:23 +08:00
Calcium-Ion
aa6d623981 Merge pull request #42 from Tailen/main
feat: 添加 davinci-002 和 babbage-002 模型
2024-01-03 22:43:28 +08:00
Hao Mei
afbd585048 feat: add support for davinci-002 and babbage-002 2024-01-03 20:19:00 +08:00
Xyfacai
5c747dfee2 optimize: MJ 部分调整、优化
MJ
增加simple-change、list接口,
变换和重试操作区别出来,价格与绘图一样
优化图片返回
2024-01-01 22:46:05 +08:00
CaIon
89dd0e05a0 update README.md 2023-12-29 16:08:50 +08:00
CaIon
7e4fe14871 更新用户时刷新缓存数据 2023-12-29 00:00:23 +08:00
CaIon
85411bc167 update README.md 2023-12-28 23:59:57 +08:00
CaIon
b86bc169a5 无效的请求返回具体原因 2023-12-28 23:59:46 +08:00
198 changed files with 17646 additions and 11566 deletions

View File

@@ -1,4 +1,4 @@
name: Publish Docker image (amd64, English)
name: Publish Docker image (amd64)
on:
push:
@@ -24,21 +24,26 @@ jobs:
run: |
git describe --tags > VERSION
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api-en
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3

View File

@@ -4,7 +4,6 @@ on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
@@ -49,7 +48,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -5,7 +5,7 @@ COPY web/package.json .
RUN npm install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
@@ -17,7 +17,7 @@ WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/build ./web/build
COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine

View File

@@ -1,10 +1,6 @@
MIT License
New API
Copyright (c) 2023 CalciumIon
Based on One API
Copyright (c) 2023 JustSong
Copyright (c) 2024 Calcium-Ion
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -1,254 +1,67 @@
# Midjourney Proxy API文档
**简介**:Midjourney Proxy API文档
## 模型列表
**HOST**:https://api.nekoedu.com
### midjourney-proxy支持
- mj_imagine (绘图)
- mj_variation (变换)
- mj_reroll (重绘)
- mj_blend (混合)
- mj_upscale (放大)
- mj_describe (图生文)
**Version**:v2.3.5
### 仅midjourney-proxy-plus支持
- mj_zoom (比例变焦)
- mj_shorten (提示词缩短)
- mj_modal (窗口提交局部重绘和自定义比例变焦必须和mj_modal一同添加)
- mj_inpaint (局部重绘提交必须和mj_modal一同添加)
- mj_custom_zoom (自定义比例变焦必须和mj_modal一同添加)
- mj_high_variation (强变换)
- mj_low_variation (弱变换)
- mj_pan (平移)
- swap_face (换脸)
[TOC]
# 任务提交
## 绘图变化
**接口地址**:`/mj/submit/change`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
```json
{
"action": "UPSCALE",
"index": 1,
"notifyHook": "",
"state": "",
"taskId": "1320098173412546"
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05
}
```
其中mj_inpaint和mj_custom_zoom的价格设置为0是因为这两个模型需要搭配mj_modal使用所以价格由mj_modal决定。
## 渠道设置
**请求参数**:
### 对接 midjourney-proxy(plus)
1.
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|changeDTO|changeDTO|body|true|变化任务提交参数|变化任务提交参数|
|  action|UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL||true|string||
|  index|序号(1~4), action为UPSCALE,VARIATION时必传||false|integer(int32)||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  state|自定义参数||false|string||
|  taskId|任务ID||true|string||
部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus**
,模型请参考上方模型列表
3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
**响应状态**:
### 对接上游new api
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
}
```
## 提交Imagine任务
**接口地址**:`/mj/submit/imagine`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"base64": "",
"notifyHook": "",
"prompt": "Cat",
"state": ""
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|imagineDTO|imagineDTO|body|true|Imagine提交参数|Imagine提交参数|
|  base64|垫图base64||false|string||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  prompt|提示词||true|string||
|  state|自定义参数||false|string||
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
}
```
# 任务查询
## 指定ID获取任务
**接口地址**:`/mj/task/{id}/fetch`
**请求方式**:`GET`
**请求数据类型**:`application/x-www-form-urlencoded`
**响应数据类型**:`*/*`
**接口描述**:
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|id|任务ID|path|false|string||
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|任务|
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|action|可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND|string||
|description|任务描述|string||
|failReason|失败原因|string||
|finishTime|结束时间|integer(int64)|integer(int64)|
|id|任务ID|string||
|imageUrl|图片url|string||
|progress|任务进度|string||
|prompt|提示词|string||
|promptEn|提示词-英文|string||
|startTime|开始执行时间|integer(int64)|integer(int64)|
|state|自定义参数|string||
|status|任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS|string||
|submitTime|提交时间|integer(int64)|integer(int64)|
**响应示例**:
```javascript
{
"action": "",
"description": "",
"failReason": "",
"finishTime": 0,
"id": "",
"imageUrl": "",
"progress": "",
"prompt": "",
"promptEn": "",
"startTime": 0,
"state": "",
"status": "",
"submitTime": 0
}
```
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
2. 地址填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥

View File

@@ -14,26 +14,60 @@
> 最新版Docker镜像 calciumion/new-api:latest
> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
## 此分叉版本的主要变更
## 主要变更
此分叉版本的主要变更如下:
1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持[对接文档](Midjourney.md),支持的接口如下
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
+ [x] 易支付
4. 支持用key查询使用额度:
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量
7. 支持 gpt-4-1106-vision-previewdall-e-3tts-1
8. 支持第三方模型 **gps** gpt-4-gizmo-*在渠道中添加自定义模型gpt-4-gizmo-*即可
9. 兼容原版One API的数据库可直接使用原版数据库one-api.db
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
11. 支持gemini-pro模型
7. 兼容原版One API的数据库可直接使用原版数据库one-api.db
8. 支持模型按次数收费,可在 系统设置-运营设置 中设置
9. 支持渠道**加权随机**
10. 数据看板
11. 可设置令牌能调用的模型
12. 支持Telegram授权登录。
1. 系统设置-配置登录注册-允许通过Telegram登录
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串
## 模型支持
此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*
2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
6. [零一万物](https://platform.lingyiwanwu.com/)
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
## 部署
### 基于 Docker 进行部署
@@ -44,11 +78,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
### 使用宝塔面板Docker功能部署
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
# 注意数据库要开启远程访问并且只允许服务器IP访问
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问
```
## Midjourney接口设置文档
[对接文档](Midjourney.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
@@ -56,8 +104,17 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

View File

@@ -9,21 +9,32 @@ import (
"github.com/google/uuid"
)
// Pay Settings
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var PayAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var ChatLink2 = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var DrawingEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute
var DataExportDefaultTime = "hour" // unit: minute
var DefaultCollapseSidebar = false // default value of collapse sidebar
// Any options with "Secret", "Token" in its key won't be return by GetOptions
@@ -40,10 +51,12 @@ var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
@@ -63,6 +76,7 @@ var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPSSLEnabled = false
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
@@ -77,11 +91,15 @@ var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var TelegramBotToken = ""
var TelegramBotName = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
@@ -94,13 +112,15 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
@@ -168,10 +188,10 @@ const (
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
@@ -191,6 +211,10 @@ const (
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
)
var ChannelBaseURLs = []string{
@@ -198,7 +222,7 @@ var ChannelBaseURLs = []string{
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
@@ -218,5 +242,12 @@ var ChannelBaseURLs = []string{
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
}

View File

@@ -2,5 +2,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -24,7 +24,7 @@ func SendEmail(subject string, receiver string, content string) error {
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 {
if SMTPPort == 465 || SMTPSSLEnabled {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,

View File

@@ -5,18 +5,37 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"strings"
)
func UnmarshalBodyReusable(c *gin.Context, v any) error {
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
return nil, err
}
err = c.Request.Body.Close()
_ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &v)
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil {
return err
}

32
common/go-channel.go Normal file
View File

@@ -0,0 +1,32 @@
package common
import (
"fmt"
"runtime/debug"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}

View File

@@ -5,14 +5,14 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/chai2010/webp"
"golang.org/x/image/webp"
"image"
"io"
"net/http"
"strings"
)
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, "", err
return image.Config{}, "", "", err
}
// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, format, err := getImageConfig(reader)
return config, format, err
return config, format, base64String, err
}
func IsImageUrl(url string) (bool, error) {
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
@@ -68,17 +69,29 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
}
defer response.Body.Close()
// 限制读取的字节数,防止下载整个图片
limitReader := io.LimitReader(response.Body, 1024*20)
//data, err := io.ReadAll(limitReader)
//if err != nil {
// log.Fatal(err)
//}
//log.Printf("%x", data)
config, format, err := getImageConfig(limitReader)
response.Body.Close()
return config, format, err
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData)))
n, _ := io.ReadFull(response.Body, additionalData)
readData = append(readData, additionalData[:n]...)
// 使用io.MultiReader组合已经读取的数据和response.Body
limitReader := io.MultiReader(bytes.NewReader(readData), response.Body)
var config image.Config
var format string
config, format, err = getImageConfig(limitReader)
if err == nil {
return config, format, nil
}
}
return image.Config{}, "", err // 返回最后一个错误
}
func getImageConfig(reader io.Reader) (image.Config, string, error) {

View File

@@ -3,18 +3,17 @@ package common
import (
"encoding/json"
"strings"
"time"
)
// ModelRatio
// modelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"midjourney": 50,
var DefaultModelRatio = map[string]float64{
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
@@ -23,15 +22,21 @@ var ModelRatio = map[string]float64{
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
@@ -48,42 +53,83 @@ var ModelRatio = map[string]float64{
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.018,
"yi-34b-chat-200k": 0.0864,
"yi-vl-plus": 0.0432,
}
var ModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
var DefaultModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
}
var modelPrice map[string]float64 = nil
var modelRatio map[string]float64 = nil
func ModelPrice2JSONString() string {
jsonBytes, err := json.Marshal(ModelPrice)
if modelPrice == nil {
modelPrice = DefaultModelPrice
}
jsonBytes, err := json.Marshal(modelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
@@ -91,24 +137,32 @@ func ModelPrice2JSONString() string {
}
func UpdateModelPriceByJSONString(jsonStr string) error {
ModelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
modelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPrice)
}
func GetModelPrice(name string) float64 {
func GetModelPrice(name string, printErr bool) float64 {
if modelPrice == nil {
modelPrice = DefaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
price, ok := ModelPrice[name]
price, ok := modelPrice[name]
if !ok {
//SysError("model price not found: " + name)
if printErr {
SysError("model price not found: " + name)
}
return -1
}
return price
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if modelRatio == nil {
modelRatio = DefaultModelRatio
}
jsonBytes, err := json.Marshal(modelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
@@ -116,15 +170,18 @@ func ModelRatio2JSONString() string {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
ModelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
modelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatio)
}
func GetModelRatio(name string) float64 {
if modelRatio == nil {
modelRatio = DefaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := ModelRatio[name]
ratio, ok := modelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
return 30
@@ -134,31 +191,38 @@ func GetModelRatio(name string) float64 {
func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333
return 4.0 / 3.0
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
if strings.HasPrefix(name, "gpt-4-turbo")|| strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
if strings.Contains(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 3
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.7
}
return 1
}

View File

@@ -18,9 +18,8 @@ func InitRedisClient() (err error) {
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
SysLog("SYNC_FREQUENCY not set, use default value 60")
SyncFrequency = 60
}
SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
@@ -57,12 +56,51 @@ func RedisGet(key string) (string, error) {
return RDB.Get(ctx, key).Result()
}
func RedisExpire(key string, expiration time.Duration) error {
ctx := context.Background()
return RDB.Expire(ctx, key, expiration).Err()
}
func RedisGetEx(key string, expiration time.Duration) (string, error) {
ctx := context.Background()
return RDB.GetSet(ctx, key, expiration).Result()
}
func RedisDel(key string) error {
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisDecrease(key string, value int64) error {
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil {
// 失败则尝试直接减少
return RDB.DecrBy(context.Background(), key, value).Err()
}
// 如果剩余生存时间大于0则进行减少操作
if ttl > 0 {
ctx := context.Background()
// 开始一个Redis事务
txn := RDB.TxPipeline()
// 减少余额
decrCmd := txn.DecrBy(ctx, key, value)
if err := decrCmd.Err(); err != nil {
return err // 如果减少失败,则直接返回错误
}
// 重新设置过期时间,使用原来的过期时间
txn.Expire(ctx, key, ttl)
// 执行事务
_, err = txn.Exec(ctx)
return err
} else {
_ = RedisDel(key)
}
return nil
}

50
common/str.go Normal file
View File

@@ -0,0 +1,50 @@
package common
func SundaySearch(text string, pattern string) bool {
// 计算偏移表
offset := make(map[rune]int)
for i, c := range pattern {
offset[c] = len(pattern) - i
}
// 文本串长度和模式串长度
n, m := len(text), len(pattern)
// 主循环i表示当前对齐的文本串位置
for i := 0; i <= n-m; {
// 检查子串
j := 0
for j < m && text[i+j] == pattern[j] {
j++
}
// 如果完全匹配,返回匹配位置
if j == m {
return true
}
// 如果还有剩余字符,则检查下一位字符在偏移表中的值
if i+m < n {
next := rune(text[i+m])
if val, ok := offset[next]; ok {
i += val // 存在于偏移表中,进行跳跃
} else {
i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度
}
} else {
break
}
}
return false // 如果没有找到匹配,返回-1
}
func RemoveDuplicate(s []string) []string {
result := make([]string, 0, len(s))
temp := map[string]struct{}{}
for _, item := range s {
if _, ok := temp[item]; !ok {
temp[item] = struct{}{}
result = append(result, item)
}
}
return result
}

View File

@@ -202,6 +202,13 @@ func GetOrDefault(env string, defaultValue int) int {
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
@@ -223,9 +230,14 @@ func StringsContains(strs []string, str string) bool {
return false
}
// []byte only read, panic on append
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}

46
constant/midjourney.go Normal file
View File

@@ -0,0 +1,46 @@
package constant
var MjNotifyEnabled = false
var MjModeClearEnabled = false
const (
MjErrorUnknown = 5
MjRequestError = 4
)
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
"mj_imagine": MjActionImagine,
"mj_describe": MjActionDescribe,
"mj_blend": MjActionBlend,
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
"mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
"mj_zoom": MjActionZoom,
"mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
}

43
constant/sensitive.go Normal file
View File

@@ -0,0 +1,43 @@
package constant
import "strings"
var CheckSensitiveEnabled = true
var CheckSensitiveOnPromptEnabled = true
//var CheckSensitiveOnCompletionEnabled = true
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
var StopOnSensitiveEnabled = true
// StreamCacheQueueLength 流模式缓存队列长度0表示无缓存
var StreamCacheQueueLength = 0
// SensitiveWords 敏感词
// var SensitiveWords []string
var SensitiveWords = []string{
"test",
}
func SensitiveWordsToString() string {
return strings.Join(SensitiveWords, "\n")
}
func SensitiveWordsFromString(s string) {
SensitiveWords = []string{}
sw := strings.Split(s, "\n")
for _, w := range sw {
w = strings.TrimSpace(w)
if w != "" {
SensitiveWords = append(SensitiveWords, w)
}
}
}
func ShouldCheckPromptSensitive() bool {
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
}
//func ShouldCheckCompletionSensitive() bool {
// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
//}

View File

@@ -3,6 +3,7 @@ package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/dto"
"one-api/model"
)
@@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
expiredTime = 0
}
if err != nil {
openAIError := OpenAIError{
openAIError := dto.OpenAIError{
Message: err.Error(),
Type: "upstream_error",
}
@@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
openAIError := OpenAIError{
openAIError := dto.OpenAIError{
Message: err.Error(),
Type: "new_api_error",
}

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"time"
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := httpClient.Do(req)
res, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
@@ -213,10 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
@@ -310,7 +309,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
service.DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)

View File

@@ -5,9 +5,17 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"strconv"
"sync"
"time"
@@ -15,81 +23,98 @@ import (
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
defer func() {
if err != nil {
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
request.Model = "gpt-3.5-turbo"
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else {
if channel.GetBaseURL() != "" {
requestURL = channel.GetBaseURL()
}
requestURL += "/v1/chat/completions"
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
jsonData, err := json.Marshal(request)
meta := relaycommon.GenRelayInfo(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
testModel = adaptor.GetModelList()[0]
}
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta, *request)
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
func buildTestRequest() *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later
MaxTokens: 1,
Stream: false,
}
content, _ := json.Marshal("hi")
testMessage := Message{
testMessage := dto.Message{
Role: "user",
Content: content,
}
@@ -114,9 +139,9 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
testModel := c.Query("model")
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
err, _ = testChannel(channel, testModel)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -140,20 +165,6 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
@@ -169,18 +180,15 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel, *testRequest)
err, openaiErr := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -197,8 +205,11 @@ func testAllChannels(notify bool) error {
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if shouldDisableChannel(openaiErr, -1) && ban {
disableChannel(channel.Id, channel.Name, err.Error())
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)

View File

@@ -35,11 +35,28 @@ func GetAllChannels(c *gin.Context) {
return
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
}
func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group)
channels, err := model.SearchChannels(keyword, group, modelKeyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -10,9 +10,13 @@ import (
func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
@@ -20,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -38,16 +42,23 @@ func GetAllLogs(c *gin.Context) {
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
if pageSize > 100 {
pageSize = 100
}
userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -10,151 +10,16 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/service"
"strconv"
"strings"
"time"
)
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
err2 := json.Unmarshal(responseBody, &responseStatus)
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
} else {
log.Printf("UpdateMidjourneyTask error3: %v", err)
continue
}
} else {
log.Printf("UpdateMidjourneyTask error4: %v", err)
continue
}
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
//imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
@@ -167,13 +32,27 @@ func UpdateMidjourneyTaskBulk() {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks {
if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
@@ -213,20 +92,24 @@ func UpdateMidjourneyTaskBulk() {
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -235,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
// 如果时间超过一小时且进度不是100%,则认为任务失败
if useTime > 3600000 && task.Progress != "100%" {
responseItem.FailReason = "上游任务超时超过1小时"
responseItem.Status = "FAILURE"
}
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@@ -249,17 +138,23 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
if responseItem.Properties != nil {
propertiesStr, _ := json.Marshal(responseItem.Properties)
task.Properties = string(propertiesStr)
}
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
@@ -279,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}

View File

@@ -5,35 +5,63 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func TestStatus(c *gin.Context) {
err := model.PingDB()
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"success": false,
"message": "数据库连接失败",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Server is running",
})
return
}
func GetStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": common.Price,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": common.Price,
"min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"chat_link2": common.ChatLink2,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled,
},
})
return
@@ -92,10 +120,20 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
if domainPart == domain {
allowed = true
break
}
@@ -103,11 +141,22 @@ func SendEmailVerification(c *gin.Context) {
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
})
return
}
}
if common.EmailAliasRestrictionEnabled {
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
if containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
})
return
}
}
if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -2,8 +2,16 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/lingyiwanwu"
relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -53,511 +61,68 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{
{
Id: "midjourney",
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1677649963,
OwnedBy: "Midjourney",
Permission: permission,
Root: "midjourney",
Parent: nil,
},
{
Id: "dall-e-2",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil,
},
{
Id: "whisper-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "whisper-1",
Parent: nil,
},
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "gpt-4-1106-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
Created: 1626777600,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Root: modelName,
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Created: 1626777600,
OwnedBy: "moonshot",
Permission: permission,
Root: "embedding-bert-512-v1",
Root: modelName,
Parent: nil,
},
{
Id: "embedding_s1_v1",
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Created: 1626777600,
OwnedBy: "lingyiwanwu",
Permission: permission,
Root: "embedding_s1_v1",
Root: modelName,
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Created: 1626777600,
OwnedBy: "midjourney",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Root: modelName,
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Parent: nil,
},
})
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
@@ -566,6 +131,29 @@ func init() {
}
func ListModels(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]OpenAIModels, 0)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
}
}
c.JSON(200, gin.H{
"object": "list",
"data": userOpenAiModels,
})
}
func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
@@ -577,7 +165,7 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := OpenAIError{
openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",

View File

@@ -1,220 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = string(request.Messages[len(request.Messages)-1].Content)
}
return &AIProxyLibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,207 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
var audioRequest AudioRequest
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} else {
audioRequest = AudioRequest{
Model: "whisper-1",
}
}
//err := common.UnmarshalBodyReusable(c, &audioRequest)
// request validation
if audioRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp)
}
var audioResponse AudioResponse
defer func(ctx context.Context) {
go func() {
quota := 0
var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = countAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens = quota
} else {
quota = countAudioToken(audioResponse.Text, audioRequest.Model)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
} else {
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,221 +0,0 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,212 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.Contains(imageRequest.Size, "×") {
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
if imageRequest.N != 1 {
return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
}
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2.5
}
qualityRatio := 1.0
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024×1792" || imageRequest.Size == "1792×1024" {
qualityRatio = 1.5
}
}
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -1,506 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type Midjourney struct {
MjId string `json:"id"`
Action string `json:"action"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
})
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Header("Content-Type", "image/jpeg")
//c.HeaderBar("Content-Length", string(rune(len(data))))
c.Data(http.StatusOK, "image/jpeg", data)
}
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
taskId := c.Param("id")
originTask := model.GetByMJId(taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
var midjourneyTask Midjourney
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
jsonMap, err := json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
var midjRequest MidjourneyRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
}
}
}
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Code: 4,
Description: "prompt_is_required",
}
}
midjRequest.Action = "IMAGINE"
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = "DESCRIBE"
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
originTask := model.GetByMJId(midjRequest.TaskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
} else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &MidjourneyResponse{
Code: 4,
Description: "upscale_task_can_not_be_change",
}
} else if originTask.Status != "SUCCESS" {
return &MidjourneyResponse{
Code: 4,
Description: "task_status_is_not_success",
}
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "channel_not_found",
}
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
log.Printf("检测到此操作为放大、变换获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_model_mapping_failed",
}
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "marshal_text_request_failed",
}
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
if midjRequest.Action == "UPSCALE" {
sizeRatio = 0.2
}
quota := int(ratio * sizeRatio * 1000)
if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "create_request_failed",
}
}
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//mjToken := ""
//if c.Request.Header.Get("Authorization") != "" {
// mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
//}
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt)
resp, err := httpClient.Do(req)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
err = req.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
var midjResponse MidjourneyResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
//if consumeQuota {
//
//}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "read_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 {
return &MidjourneyResponse{
Code: 4,
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
}
}
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
// 22-排队中 {"code":22,"description":"排队中前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description
consumeQuota = false
}
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
// 将 properties 转换为一个 map
properties, ok := midjResponse.Properties.(map[string]interface{})
if ok {
imageUrl, ok1 := properties["imageUrl"].(string)
status, ok2 := properties["status"].(string)
if ok1 && ok2 {
midjourneyTask.ImageUrl = imageUrl
midjourneyTask.Status = status
if status == "SUCCESS" {
midjourneyTask.Progress = "100%"
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
midjResponse.Code = 1
}
}
}
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
err = midjourneyTask.Insert()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}

View File

@@ -1,735 +0,0 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1beta"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
//log.Println(fullRequestURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
if err != nil {
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
modelPrice := common.GetModelPrice(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// 设置GetBody函数该函数返回一个新的io.ReadCloser该io.ReadCloser返回与原始请求体相同的数据
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if c.Request.Header.Get("OpenAI-Organization") != "" {
req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
}
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypeGemini:
req.Header.Set("Content-Type", "application/json")
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
if apiType != APITypeGemini {
// 设置公共头部...
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
// record all the consume log even if quota is 0
useTimeSeconds := time.Now().Unix() - startTime.Unix()
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
//if quota != 0 {
//
//}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

View File

@@ -1,389 +1,163 @@
package controller
import (
"encoding/json"
"bytes"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
)
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.RelayImageHelper(c, relayMode)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c, relayMode)
default:
err = relay.TextHelper(c)
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeAudio
)
// https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type AudioRequest struct {
Model string `json:"model"`
Voice string `json:"voice"`
Input string `json:"input"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens uint `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens uint `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
}
type AudioResponse struct {
Text string `json:"text,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
return err
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
retryLogStr := fmt.Sprintf("重试:%d", channelId)
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
} else {
retryTimes = 0
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudio:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
retryLogStr += fmt.Sprintf("->%d", channel.Id)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
}
}
if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
common.LogInfo(c.Request.Context(), retryLogStr)
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
if openaiErr.StatusCode == 307 {
return true
}
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
}
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
return false
}
if openaiErr.LocalError {
return false
}
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
func RelayMidjourney(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
relayMode = RelayModeMidjourneyDescribe
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
relayMode = RelayModeMidjourneyTaskFetch
}
var err *MidjourneyResponse
relayMode := c.GetInt("relay_mode")
var err *dto.MidjourneyResponse
switch relayMode {
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch:
err = relayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
default:
err = relayMidjourneySubmit(c, relayMode)
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(400, gin.H{
"error": err.Description + " " + err.Result,
})
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
//if shouldDisableChannel(&err.OpenAIError) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//};''''''''''''''''''''''''''''''''
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
err := dto.OpenAIError{
Message: "API not implemented",
Type: "new_api_error",
Param: "",
@@ -395,7 +169,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := OpenAIError{
err := dto.OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",

116
controller/telegram.go Normal file
View File

@@ -0,0 +1,116 @@
package controller
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"io"
"one-api/common"
"one-api/model"
"sort"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
func TelegramBind(c *gin.Context) {
if !common.TelegramOAuthEnabled {
c.JSON(200, gin.H{
"message": "管理员未开启通过 Telegram 登录以及注册",
"success": false,
})
return
}
params := c.Request.URL.Query()
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
c.JSON(200, gin.H{
"message": "无效的请求",
"success": false,
})
return
}
telegramId := params["id"][0]
if model.IsTelegramIdAlreadyTaken(telegramId) {
c.JSON(200, gin.H{
"message": "该 Telegram 账户已被绑定",
"success": false,
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user := model.User{Id: id.(int)}
if err := user.FillUserById(); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
user.TelegramId = telegramId
if err := user.Update(false); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
c.Redirect(302, "/setting")
}
func TelegramLogin(c *gin.Context) {
if !common.TelegramOAuthEnabled {
c.JSON(200, gin.H{
"message": "管理员未开启通过 Telegram 登录以及注册",
"success": false,
})
return
}
params := c.Request.URL.Query()
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
c.JSON(200, gin.H{
"message": "无效的请求",
"success": false,
})
return
}
telegramId := params["id"][0]
user := model.User{TelegramId: telegramId}
if err := user.FillUserByTelegramId(); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
setupLogin(&user, c)
}
func checkTelegramAuthorization(params map[string][]string, token string) bool {
strs := []string{}
var hash = ""
for k, v := range params {
if k == "hash" {
hash = v[0]
continue
}
strs = append(strs, k+"="+v[0])
}
sort.Strings(strs)
var imploded = ""
for _, s := range strs {
if imploded != "" {
imploded += "\n"
}
imploded += s
}
sha256hash := sha256.New()
io.WriteString(sha256hash, token)
hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
io.WriteString(hmachash, imploded)
ss := hex.EncodeToString(hmachash.Sum(nil))
return hash == ss
}

View File

@@ -124,14 +124,16 @@ func AddToken(c *gin.Context) {
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits,
}
err = cleanToken.Insert()
if err != nil {
@@ -217,6 +219,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits
}
err = cleanToken.Update()
if err != nil {

View File

@@ -2,14 +2,17 @@ package controller
import (
"fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
epay "github.com/star-horizon/go-epay"
"log"
"net/url"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"sync"
"time"
)
@@ -28,7 +31,7 @@ func GetEpayClient() *epay.Client {
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
return nil
}
withUrl, err := epay.NewClientWithUrl(&epay.Config{
withUrl, err := epay.NewClient(&epay.Config{
PartnerID: common.EpayId,
Key: common.EpayKey,
}, common.PayAddress)
@@ -38,31 +41,46 @@ func GetEpayClient() *epay.Client {
return withUrl
}
func GetAmount(count float64, user model.User) float64 {
func getPayMoney(amount float64, user model.User) float64 {
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
}
// 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
amount := count * common.Price * topupGroupRatio
return amount
payMoney := amount * common.Price * topupGroupRatio
return payMoney
}
func getMinTopup() int {
minTopup := common.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
}
return minTopup
}
func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10})
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
amount := GetAmount(float64(req.Amount), *user)
payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
var payType epay.PurchaseType
if req.PaymentMethod == "zfb" {
@@ -72,11 +90,10 @@ func RequestEpay(c *gin.Context) {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
payMoney := amount
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
@@ -95,9 +112,13 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Amount: amount,
Money: payMoney,
TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(),
@@ -111,6 +132,33 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
// tradeNo lock
var orderLocks sync.Map
var createLock sync.Mutex
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
@@ -122,6 +170,7 @@ func EpayNotify(c *gin.Context) {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
@@ -135,11 +184,19 @@ func EpayNotify(c *gin.Context) {
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
@@ -149,13 +206,13 @@ func EpayNotify(c *gin.Context) {
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
@@ -169,12 +226,17 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"})
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
payMoney := GetAmount(float64(req.Amount), *user)
payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}

56
controller/usedata.go Normal file
View File

@@ -0,0 +1,56 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"strconv"
)
func GetAllQuotaDates(c *gin.Context) {
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
username := c.Query("username")
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dates,
})
return
}
func GetUserQuotaDates(c *gin.Context) {
userId := c.GetInt("id")
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 判断时间跨度是否超过 1 个月
if endTimestamp-startTimestamp > 2592000 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "时间跨度不能超过 1 个月",
})
return
}
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dates,
})
return
}

View File

@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"sync"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
@@ -683,7 +684,7 @@ func ManageUser(c *gin.Context) {
})
return
}
if err := user.HardDelete(); err != nil {
if err := user.Delete(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -724,7 +725,7 @@ func ManageUser(c *gin.Context) {
user.Role = common.RoleCommonUser
}
if err := user.Update(false); err != nil {
if err := user.UpdateAll(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -789,7 +790,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
var lock = sync.Mutex{}
func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {

View File

@@ -3,6 +3,7 @@ version: '3.4'
services:
new-api:
image: calciumion/new-api:latest
# build: .
container_name: new-api
restart: always
command: --log-dir /app/logs

13
dto/audio.go Normal file
View File

@@ -0,0 +1,13 @@
package dto
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type AudioResponse struct {
Text string `json:"text"`
}

20
dto/dalle.go Normal file
View File

@@ -0,0 +1,20 @@
package dto
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}

55
dto/error.go Normal file
View File

@@ -0,0 +1,55 @@
package dto
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"`
LocalError bool
}
type GeneralErrorResponse struct {
Error OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}

95
dto/midjourney.go Normal file
View File

@@ -0,0 +1,95 @@
package dto
//type SimpleMjRequest struct {
// Prompt string `json:"prompt"`
// CustomId string `json:"customId"`
// Action string `json:"action"`
// Content string `json:"content"`
//}
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
}
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
type Properties struct {
FinalPrompt string `json:"finalPrompt"`
FinalZhPrompt string `json:"finalZhPrompt"`
}

6
dto/sensitive.go Normal file
View File

@@ -0,0 +1,6 @@
package dto
type SensitiveResponse struct {
SensitiveWords []string `json:"sensitive_words"`
Content string `json:"content"`
}

141
dto/text_request.go Normal file
View File

@@ -0,0 +1,141 @@
package dto
import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) StringContent() string {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
}
return string(m.Content)
}
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return true
}
return false
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}

89
dto/text_response.go Normal file
View File

@@ -0,0 +1,89 @@
package dto
type TextResponseWithError struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type SimpleResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
}
type TextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
Role string `json:"role,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

24
go.mod
View File

@@ -4,22 +4,23 @@ module one-api
go 1.18
require (
github.com/chai2010/webp v1.1.1
github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/go-playground/validator/v10 v10.19.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.6
github.com/samber/lo v1.38.1
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.17.0
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -27,12 +28,13 @@ require (
)
require (
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -50,7 +52,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
@@ -63,10 +65,10 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

57
go.sum
View File

@@ -1,10 +1,16 @@
github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc=
github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -17,8 +23,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
@@ -47,10 +53,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
@@ -79,8 +83,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
@@ -106,12 +108,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
@@ -141,6 +141,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
@@ -155,9 +157,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -176,17 +177,21 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -197,16 +202,12 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

28
main.go
View File

@@ -13,16 +13,17 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/router"
"one-api/service"
"os"
"strconv"
_ "net/http/pprof"
)
//go:embed web/build
//go:embed web/dist
var buildFS embed.FS
//go:embed web/build/index.html
//go:embed web/dist/index.html
var indexPage []byte
func main() {
@@ -63,10 +64,17 @@ func main() {
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.RedisEnabled {
go model.SyncTokenCache(common.SyncFrequency)
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
// 数据看板
go model.UpdateQuotaData()
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
@@ -81,7 +89,9 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTaskBulk()
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
@@ -96,11 +106,19 @@ func main() {
common.SysLog("pprof enabled")
}
controller.InitTokenEncoders()
service.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()
server.Use(gin.Recovery())
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
common.SysError(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
"type": "new_api_panic",
},
})
}))
// This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId())

14
makefile Normal file
View File

@@ -0,0 +1,14 @@
FRONTEND_DIR = ./web
BACKEND_DIR = .
.PHONY: all build-frontend start-backend
all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
start-backend:
@echo "Starting backend dev server..."
@cd $(BACKEND_DIR) && go run main.go &

View File

@@ -100,32 +100,36 @@ func TokenAuth() func(c *gin.Context) {
}
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
c.Set("token_unlimited_quota", token.UnlimitedQuota)
if !token.UnlimitedQuota {
c.Set("token_quota", token.RemainQuota)
}
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
c.Set("specific_channel_id", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}

View File

@@ -4,7 +4,11 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
@@ -18,98 +22,174 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
// Midjourney
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, err
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
}

28
middleware/recover.go Normal file
View File

@@ -0,0 +1,28 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
"type": "new_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -5,7 +5,7 @@ import (
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
common.LogError(c.Request.Context(), message)
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
c.JSON(statusCode, gin.H{
"description": description,
"type": "new_api_error",
"code": code,
})
c.Abort()
common.LogError(c.Request.Context(), description)
}

View File

@@ -1,13 +1,17 @@
package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
@@ -25,8 +29,7 @@ func GetGroupModels(group string) []string {
return models
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
var abilities []Ability
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@@ -34,9 +37,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
trueVal = "true"
}
var err error = nil
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
// 处理错误
return 0, err
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
// 如果重试次数大于优先级数,则使用最小的优先级
priorityToUse = priorities[len(priorities)-1]
} else {
priorityToUse = priorities[retry]
}
return priorityToUse, nil
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
return channelQuery
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil
channelQuery := getChannelQuery(group, model, retry)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
@@ -50,25 +99,20 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight
weightSum += ability_.Weight + 10
}
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight) + 10
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
} else {
return nil, nil
return nil, errors.New("channel not found")
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
@@ -91,7 +135,16 @@ func (channel *Channel) AddAbilities() error {
abilities = append(abilities, ability)
}
}
return DB.Create(&abilities).Error
if len(abilities) == 0 {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Create(&chunk).Error
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) DeleteAbilities() error {
@@ -118,3 +171,51 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
// Delete abilities of channels that are not in channel table
err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
return 0, err
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
}
if err != nil {
return 0, err
}
for _, channel := range channels {
err := channel.UpdateAbilities()
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return count, nil
}

View File

@@ -20,34 +20,88 @@ var (
UserId2StatusCacheSeconds = common.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
// 仅用于定时同步缓存
var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error {
if !common.RedisEnabled {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
return token.SelectUpdate()
}
jsonBytes, err := json.Marshal(token)
if err != nil {
return err
}
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
return err
}
token2UserIdLock.Lock()
defer token2UserIdLock.Unlock()
token2UserId[token.Key] = token.UserId
return nil
}
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
func CacheGetTokenByKey(key string) (*Token, error) {
if !common.RedisEnabled {
return GetTokenByKey(key)
}
var token *Token
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
// 如果缓存中不存在,则从数据库中获取
token, err = GetTokenByKey(key)
if err != nil {
return nil, err
}
jsonBytes, err := json.Marshal(token)
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set token error: " + err.Error())
}
return &token, nil
err = cacheSetToken(token)
return token, nil
}
// 如果缓存中存在,则续期时间
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
err = json.Unmarshal([]byte(tokenObjectString), &token)
return &token, err
return token, err
}
func SyncTokenCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing tokens from database")
token2UserIdLock.Lock()
// 从token2UserId中获取所有的key
var copyToken2UserId = make(map[string]int)
for s, i := range token2UserId {
copyToken2UserId[s] = i
}
token2UserId = make(map[string]int)
token2UserIdLock.Unlock()
for key := range copyToken2UserId {
token, err := GetTokenByKey(key)
if err != nil {
// 如果数据库中不存在,则删除缓存
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
//delete redis
err := common.RedisDel(fmt.Sprintf("token:%s", key))
if err != nil {
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
}
} else {
// 如果数据库中存在先检查redis
_, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果redis中不存在则跳过
continue
}
err = cacheSetToken(token)
if err != nil {
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
}
}
}
}
}
func CacheGetUserGroup(id int) (group string, err error) {
@@ -68,6 +122,24 @@ func CacheGetUserGroup(id int) (group string, err error) {
return group, err
}
func CacheGetUsername(id int) (username string, err error) {
if !common.RedisEnabled {
return GetUsernameById(id)
}
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
if err != nil {
username, err = GetUsernameById(id)
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
}
return username, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
@@ -193,14 +265,14 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, retry)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -208,40 +280,49 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
uniquePriorities := make(map[int]bool)
for _, channel := range channels {
uniquePriorities[int(channel.GetPriority())] = true
}
var sortedUniquePriorities []int
for priority := range uniquePriorities {
sortedUniquePriorities = append(sortedUniquePriorities, priority)
}
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
if retry >= len(uniquePriorities) {
retry = len(uniquePriorities) - 1
}
targetPriority := int64(sortedUniquePriorities[retry])
// get the priority for the given retry number
var targetChannels []*Channel
for _, channel := range channels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
}
}
// 平滑系数
smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
for _, channel := range targetChannels {
totalWeight += channel.GetWeight() + smoothingFactor
}
if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
for _, channel := range targetChannels {
randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 {
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
// return null if no channel is not found
return nil, errors.New("channel not found")
}
func CacheGetChannel(id int) (*Channel, error) {

View File

@@ -8,8 +8,9 @@ import (
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null;index"`
Key string `json:"key" gorm:"not null"`
OpenAIOrganization *string `json:"openai_organization"`
TestModel *string `json:"test_model"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
@@ -21,7 +22,7 @@ type Channel struct {
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(255);default:'default'"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
@@ -43,21 +44,39 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}
func SearchChannels(keyword string, group string) (channels []*Channel, err error) {
func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
// 构造WHERE子句
var whereClause string
var args []interface{}
if group != "" {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Omit("key").Where("(id = ? or name LIKE ? or "+keyCol+" = ?) and "+groupCol+" LIKE ?", common.String2Int(keyword), keyword+"%", keyword, "%"+group+"%").Find(&channels).Error
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
} else {
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
}
return channels, err
// 执行查询
err := baseQuery.Where(whereClause, args...).Find(&channels).Error
if err != nil {
return nil, err
}
return channels, nil
}
func GetChannelById(id int, selectAll bool) (*Channel, error) {

View File

@@ -9,7 +9,7 @@ import (
)
type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"`
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
@@ -20,6 +20,8 @@ type Log struct {
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"`
TokenId int `json:"token_id" gorm:"default:0;index"`
}
@@ -41,9 +43,10 @@ func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: common.GetTimestamp(),
Type: logType,
Content: content,
@@ -54,14 +57,15 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
@@ -72,11 +76,18 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
Quota: quota,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
}
err := DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
common.SafeGoroutine(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
})
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {

View File

@@ -5,9 +5,11 @@ import (
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"log"
"one-api/common"
"os"
"strings"
"sync"
"time"
)
@@ -60,6 +62,7 @@ func chooseDB() (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -90,6 +93,12 @@ func InitDB() (err error) {
if !common.IsMasterNode {
return nil
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
@@ -127,6 +136,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
@@ -144,3 +157,33 @@ func CloseDB() error {
err = sqlDB.Close()
return err
}
var (
lastPingTime time.Time
pingMutex sync.Mutex
)
func PingDB() error {
pingMutex.Lock()
defer pingMutex.Unlock()
if time.Since(lastPingTime) < time.Second*10 {
return nil
}
sqlDB, err := DB.DB()
if err != nil {
log.Printf("Error getting sql.DB from GORM: %v", err)
return err
}
err = sqlDB.Ping()
if err != nil {
log.Printf("Error pinging DB: %v", err)
return err
}
lastPingTime = time.Now()
common.SysLog("Database pinged successfully")
return nil
}

View File

@@ -4,20 +4,23 @@ type Midjourney struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
Action string `json:"action" gorm:"type:varchar(40);index"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
Status string `json:"status"`
Progress string `json:"progress"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -96,7 +99,7 @@ func GetAllUnFinishTasks() []*Midjourney {
return tasks
}
func GetByMJId(mjId string) *Midjourney {
func GetByOnlyMJId(mjId string) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("mj_id = ?", mjId).First(&mj).Error
@@ -106,6 +109,26 @@ func GetByMJId(mjId string) *Midjourney {
return mj
}
func GetByMJId(userId int, mjId string) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
var mj []*Midjourney
var err error
err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetMjByuId(id int) *Midjourney {
var mj *Midjourney
var err error
@@ -132,8 +155,14 @@ func (midjourney *Midjourney) Update() error {
return err
}
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
func MjBulkUpdate(mjIds []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Where("mj_id in (?)", mjIds).
Updates(params).Error
}
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}

View File

@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
"one-api/constant"
"strconv"
"strings"
"time"
@@ -30,21 +31,27 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
@@ -53,12 +60,16 @@ func InitOptionMap() {
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
@@ -74,8 +85,20 @@ func InitOptionMap() {
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -132,7 +155,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -145,18 +168,46 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
common.TelegramOAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue
case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue
case "DrawingEnabled":
common.DrawingEnabled = boolValue
case "DataExportEnabled":
common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
}
}
switch key {
@@ -177,12 +228,16 @@ func updateOptionMap(key string, value string) (err error) {
common.ServerAddress = value
case "PayAddress":
common.PayAddress = value
case "CustomCallbackAddress":
common.CustomCallbackAddress = value
case "EpayId":
common.EpayId = value
case "EpayKey":
common.EpayKey = value
case "Price":
common.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
@@ -201,6 +256,10 @@ func updateOptionMap(key string, value string) (err error) {
common.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
case "TelegramBotToken":
common.TelegramBotToken = value
case "TelegramBotName":
common.TelegramBotName = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
case "TurnstileSecretKey":
@@ -217,6 +276,10 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
common.DataExportInterval, _ = strconv.Atoi(value)
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
@@ -227,10 +290,16 @@ func updateOptionMap(key string, value string) (err error) {
common.TopUpLink = value
case "ChatLink":
common.ChatLink = value
case "ChatLink2":
common.ChatLink2 = value
case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
constant.SensitiveWordsFromString(value)
case "StreamCacheQueueLength":
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}

View File

@@ -8,16 +8,17 @@ import (
)
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
@@ -55,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if common.UsingPostgreSQL {
keyCol = `"key"`
}
common.RandomSleep()
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {

View File

@@ -10,17 +10,20 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -45,7 +48,9 @@ func ValidateUserToken(key string) (token *Token, err error) {
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -71,7 +76,9 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
}
return token, nil
}
@@ -98,6 +105,16 @@ func GetTokenById(id int) (*Token, error) {
return &token, err
}
func GetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
func (token *Token) Insert() error {
var err error
err = DB.Create(token).Error
@@ -107,7 +124,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
return err
}
@@ -122,6 +139,36 @@ func (token *Token) Delete() error {
return err
}
func (token *Token) IsModelLimitsEnabled() bool {
return token.ModelLimitsEnabled
}
func (token *Token) GetModelLimits() []string {
if token.ModelLimits == "" {
return []string{}
}
return strings.Split(token.ModelLimits, ",")
}
func (token *Token) GetModelLimitsMap() map[string]bool {
limits := token.GetModelLimits()
limitsMap := make(map[string]bool)
for _, limit := range limits {
limitsMap[limit] = true
}
return limitsMap
}
func DisableModelLimits(tokenId int) error {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
token.ModelLimitsEnabled = false
token.ModelLimits = ""
return token.Update()
}
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {

131
model/usedata.go Normal file
View File

@@ -0,0 +1,131 @@
package model
import (
"fmt"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
)
// QuotaData 柱状图数据
type QuotaData struct {
Id int `json:"id"`
UserID int `json:"user_id" gorm:"index"`
Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
TokenUsed int `json:"token_used" gorm:"default:0"`
Count int `json:"count" gorm:"default:0"`
Quota int `json:"quota" gorm:"default:0"`
}
func UpdateQuotaData() {
// recover
defer func() {
if r := recover(); r != nil {
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
}
}()
for {
if common.DataExportEnabled {
common.SysLog("正在更新数据看板数据...")
SaveQuotaDataCache()
}
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
}
}
var CacheQuotaData = make(map[string]*QuotaData)
var CacheQuotaDataLock = sync.Mutex{}
func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
quotaData, ok := CacheQuotaData[key]
if ok {
quotaData.Count += 1
quotaData.Quota += quota
} else {
quotaData = &QuotaData{
UserID: userId,
Username: username,
ModelName: modelName,
CreatedAt: createdAt,
Count: 1,
Quota: quota,
TokenUsed: tokenUsed,
}
}
CacheQuotaData[key] = quotaData
}
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
// 只精确到小时
createdAt = createdAt - (createdAt % 3600)
CacheQuotaDataLock.Lock()
defer CacheQuotaDataLock.Unlock()
logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed)
}
func SaveQuotaDataCache() {
CacheQuotaDataLock.Lock()
defer CacheQuotaDataLock.Unlock()
size := len(CacheQuotaData)
// 如果缓存中有数据,就保存到数据库中
// 1. 先查询数据库中是否有数据
// 2. 如果有数据,就更新数据
// 3. 如果没有数据,就插入数据
for _, quotaData := range CacheQuotaData {
quotaDataDB := &QuotaData{}
DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
if quotaDataDB.Id > 0 {
//quotaDataDB.Count += quotaData.Count
//quotaDataDB.Quota += quotaData.Quota
//DB.Table("quota_data").Save(quotaDataDB)
increaseQuotaData(quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.Count, quotaData.Quota, quotaData.CreatedAt)
} else {
DB.Table("quota_data").Create(quotaData)
}
}
CacheQuotaData = make(map[string]*QuotaData)
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
}
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64) {
err := DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
userId, username, modelName, createdAt).Updates(map[string]interface{}{
"count": gorm.Expr("count + ?", count),
"quota": gorm.Expr("quota + ?", quota),
}).Error
if err != nil {
common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
}
}
func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
err = DB.Table("quota_data").Where("username = ? and created_at >= ? and created_at <= ?", username, startTime, endTime).Find(&quotaDatas).Error
return quotaDatas, err
}
func GetQuotaDataByUserId(userId int, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
err = DB.Table("quota_data").Where("user_id = ? and created_at >= ? and created_at <= ?", userId, startTime, endTime).Find(&quotaDatas).Error
return quotaDatas, err
}
func GetAllQuotaDates(startTime int64, endTime int64, username string) (quotaData []*QuotaData, err error) {
if username != "" {
return GetQuotaDataByUsername(username, startTime, endTime)
}
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
// only select model_name, sum(count) as count, sum(quota) as quota, model_name, created_at from quota_data group by model_name, created_at;
//err = DB.Table("quota_data").Where("created_at >= ? and created_at <= ?", startTime, endTime).Find(&quotaDatas).Error
err = DB.Table("quota_data").Select("model_name, sum(count) as count, sum(quota) as quota, created_at").Where("created_at >= ? and created_at <= ?", startTime, endTime).Group("model_name, created_at").Find(&quotaDatas).Error
return quotaDatas, err
}

View File

@@ -3,9 +3,12 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
@@ -20,12 +23,13 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
@@ -69,8 +73,26 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
func SearchUsers(keyword string) ([]*User, error) {
var users []*User
var err error
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功按照ID搜索用户
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
if err != nil || len(users) > 0 {
// 如果依据ID找到用户或者发生错误返回结果或错误
return users, err
}
}
// 如果ID转换失败或者没有找到用户依据其他字段进行模糊搜索
err = DB.Unscoped().Omit("password").
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
Find(&users).Error
return users, err
}
@@ -202,9 +224,34 @@ func (user *User) Update(updatePassword bool) error {
}
}
newUser := *user
DB.First(&user, user.Id)
err = DB.Model(user).Updates(newUser).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
}
return err
}
func (user *User) UpdateAll(updatePassword bool) error {
var err error
if updatePassword {
user.Password, err = common.Password2Hash(user.Password)
if err != nil {
return err
}
}
newUser := *user
DB.First(&user, user.Id)
err = DB.Model(user).Select("*").Updates(newUser).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
}
return err
}
@@ -281,6 +328,17 @@ func (user *User) FillUserByUsername() error {
return nil
}
func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" {
return errors.New("Telegram id 为空!")
}
err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("该 Telegram 账户未绑定")
}
return nil
}
func IsEmailAlreadyTaken(email string) bool {
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
}
@@ -297,6 +355,10 @@ func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
}
func ResetUserPasswordByEmail(email string, password string) error {
if email == "" || password == "" {
return errors.New("邮箱地址或密码为空!")
@@ -447,7 +509,7 @@ func updateUserRequestCount(id int, count int) {
}
}
func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username
func GetUsernameById(id int) (username string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
return username, err
}

21
relay/channel/adapter.go Normal file
View File

@@ -0,0 +1,21 @@
package channel
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
)
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

View File

@@ -0,0 +1,8 @@
package ai360
var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}

View File

@@ -0,0 +1,80 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
if info.RelayMode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Ali(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,8 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}
var ChannelName = "ali"

72
relay/channel/ali/dto.go Normal file
View File

@@ -0,0 +1,72 @@
package ali
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
Prompt string `json:"prompt,omitempty"`
//History []AliMessage `json:"history,omitempty"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input,omitempty"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}

View File

@@ -1,4 +1,4 @@
package controller
package ali
import (
"bufio"
@@ -7,119 +7,46 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
const EnableSearchModelSuffix = "-internet"
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
//prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: string(message.Content),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = string(message.Content)
break
}
messages = append(messages, AliMessage{
User: string(message.Content),
Bot: string(request.Messages[i+1].Content),
})
i++
}
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
@@ -130,21 +57,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -157,7 +84,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -165,16 +92,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
@@ -183,22 +110,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text)
choice := OpenAITextResponseChoice{
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: Message{
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
fullTextResponse := dto.OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: dto.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
@@ -207,25 +134,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
response := dto.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -255,7 +182,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
}
stopChan <- true
}()
setEventStreamHeaders(c)
service.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
@@ -288,28 +215,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -321,7 +248,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -0,0 +1,52 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/relay/common"
"one-api/service"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}

View File

@@ -0,0 +1,92 @@
package baidu
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.UpstreamModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = baiduStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,12 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
}
var ChannelName = "baidu"

View File

@@ -0,0 +1,71 @@
package baidu
import (
"one-api/dto"
"time"
)
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}

View File

@@ -1,4 +1,4 @@
package controller
package baidu
import (
"bufio"
@@ -9,6 +9,9 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"sync"
"time"
@@ -16,91 +19,15 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: string(message.Content),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: string(message.Content),
})
}
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
@@ -108,57 +35,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
}
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Result)
choice := OpenAITextResponseChoice{
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: Message{
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []OpenAITextResponseChoice{choice},
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &stopFinishReason
choice.FinishReason = &relaycommon.StopFinishReason
}
response := ChatCompletionsStreamResponse{
response := dto.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
@@ -167,8 +94,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -195,7 +122,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
}
stopChan <- true
}()
setEventStreamHeaders(c)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -225,28 +152,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -258,7 +185,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -266,23 +193,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -294,7 +221,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -337,7 +264,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := impatientHTTPClient.Do(req)
res, err := service.GetImpatientHttpClient().Do(req)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,81 @@
package claude
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return requestOpenAI2ClaudeComplete(*request), nil
} else {
return requestOpenAI2ClaudeMessage(*request)
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,13 @@
package claude
var ModelList = []string{
"claude-instant-1.2",
"claude-2",
"claude-2.0",
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
}
var ChannelName = "claude"

View File

@@ -0,0 +1,68 @@
package claude
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -0,0 +1,344 @@
package claude
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "end_turn":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
claudeMessages := make([]ClaudeMessage, 0)
for _, message := range textRequest.Messages {
if message.Role == "system" {
claudeRequest.System = message.StringContent()
} else {
claudeMessage := ClaudeMessage{
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
} else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
return &claudeRequest, nil
}
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
} else {
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
}
}
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
}
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for i, message := range claudeResponse.Content {
content, _ := json.Marshal(message.Text)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
}
}
fullTextResponse.Choices = choices
return &fullTextResponse
}
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
} else {
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
}
}
return nil, usage
}
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = promptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -0,0 +1,67 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := "v1"
if info.ApiVersion != "" {
version = info.ApiVersion
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return CovertGemini2OpenAI(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,12 @@
package gemini
const (
GeminiVisionMaxImageNum = 16
)
var ModelList = []string{
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}
var ChannelName = "google gemini"

View File

@@ -0,0 +1,62 @@
package gemini
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}

View File

@@ -1,4 +1,4 @@
package controller
package gemini
import (
"bufio"
@@ -7,77 +7,36 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
// {
// Category: "HARM_CATEGORY_HARASSMENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_HATE_SPEECH",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
//},
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -97,7 +56,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
Role: message.Role,
Parts: []GeminiPart{
{
Text: string(message.Content),
Text: message.StringContent(),
},
},
}
@@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
@@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
@@ -169,41 +123,25 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: Message{
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: stopFinishReason,
FinishReason: relaycommon.StopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
@@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
}
stopChan <- true
}()
setEventStreamHeaders(c)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
@@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
@@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -0,0 +1,9 @@
package lingyiwanwu
// https://platform.lingyiwanwu.com/docs
var ModelList = []string{
"yi-34b-chat-0205",
"yi-34b-chat-200k",
"yi-vl-plus",
}

View File

@@ -0,0 +1,7 @@
package moonshot
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
}

View File

@@ -0,0 +1,73 @@
package ollama
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embeddings", nil
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
return requestOpenAI2Ollama(*request), nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package ollama
var ModelList []string
var ChannelName = "ollama"

View File

@@ -0,0 +1,26 @@
package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
}
type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"`
}
//type OllamaOptions struct {
//}

View File

@@ -0,0 +1,108 @@
package ollama
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
})
}
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
}
}
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Prompt: request.Input,
}
}
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding,
Object: "embedding",
})
usage := &dto.Usage{
TotalTokens: promptTokens,
CompletionTokens: 0,
PromptTokens: promptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
Model: model,
Usage: *usage,
}
doResponseBody, err := json.Marshal(embeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header {
// 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
}
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}

View File

@@ -0,0 +1,98 @@
package openai
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
type Adaptor struct {
ChannelType int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey)
return nil
}
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
req.Header.Set("OpenAI-Organization", info.Organization)
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
//if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API")
//}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
case common.ChannelType360:
return ai360.ModelList
case common.ChannelTypeMoonshot:
return moonshot.ModelList
case common.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
default:
return ModelList
}
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,21 @@
package openai
var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001",
"davinci-002", "babbage-002",
"dall-e-2", "dall-e-3",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}
var ChannelName = "openai"

View File

@@ -1,4 +1,4 @@
package controller
package openai
import (
"bufio"
@@ -8,12 +8,16 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -33,11 +37,10 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string
var streamItems []string // store stream items
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
@@ -54,28 +57,46 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case RelayModeChatCompletions:
var streamResponses []ChatCompletionsStreamResponseSimple
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
}
case RelayModeCompletions:
var streamResponses []CompletionsStreamResponse
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
}
@@ -83,9 +104,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
// wait data out
time.Sleep(2 * time.Second)
}
stopChan <- true
common.SafeSend(stopChan, true)
}()
setEventStreamHeaders(c)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -102,30 +123,30 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
wg.Wait()
return nil, responseTextBuilder.String()
}
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var simpleResponse dto.SimpleResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
err = json.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
if simpleResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: simpleResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
@@ -140,23 +161,24 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
if simpleResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(string(choice.Message.Content), model)
for _, choice := range simpleResponse.Choices {
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
completionTokens += ctkm
}
textResponse.Usage = Usage{
simpleResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
return nil, &textResponse.Usage
return nil, &simpleResponse.Usage
}

View File

@@ -0,0 +1,59 @@
package palm
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package palm
var ModelList = []string{
"PaLM-2",
}
var ChannelName = "google palm"

38
relay/channel/palm/dto.go Normal file
View File

@@ -0,0 +1,38 @@
package palm
import "one-api/dto"
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []dto.Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}

View File

@@ -1,4 +1,4 @@
package controller
package palm
import (
"encoding/json"
@@ -7,47 +7,15 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@@ -59,7 +27,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: string(message.Content),
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
@@ -71,15 +39,15 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
content, _ := json.Marshal(candidate.Content)
choice := OpenAITextResponseChoice{
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: Message{
Message: dto.Message{
Role: "assistant",
Content: content,
},
@@ -90,20 +58,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
@@ -144,7 +112,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
dataChan <- string(jsonResponse)
stopChan <- true
}()
setEventStreamHeaders(c)
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@@ -157,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@@ -188,8 +156,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
usage := Usage{
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@@ -197,7 +165,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -0,0 +1,63 @@
package perplexity
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if request.TopP >= 1 {
request.TopP = 0.99
}
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package perplexity
var ModelList = []string{
"sonar-small-chat", "sonar-small-online", "sonar-medium-chat", "sonar-medium-online", "mistral-7b-instruct", "mixtral-8x7b-instruct",
}
var ChannelName = "perplexity"

View File

@@ -0,0 +1,21 @@
package perplexity
import "one-api/dto"
func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
}
}

Some files were not shown because too many files have changed in this diff Show More