Compare commits

...

161 Commits

Author SHA1 Message Date
CalciumIon
b40c2e1071 feat: 美化日志页面
(cherry picked from commit 90daa38d5bea7b158ebed9990f042f6bf8567eb3)
2024-11-05 20:45:01 +08:00
Xyfacai
afc1e92ed0 fix: log table unknown ws prop error 2024-11-05 20:20:19 +08:00
1808837298@qq.com
ee04dbd9dd feat: 日志详情完善
(cherry picked from commit ec79110c99e9b4c076c5f7b8285e535b9c5052db)
2024-11-05 20:19:58 +08:00
CalciumIon
e5588fc1ee Update README.md 2024-11-05 19:48:03 +08:00
Calcium-Ion
a859ff5985 Merge pull request #551 from Calcium-Ion/realtime
feat: support openai realtime api
2024-11-05 19:45:43 +08:00
CalciumIon
0a80231e18 chore: 删除无用日志 2024-11-05 19:41:38 +08:00
CalciumIon
7b1ff41e4c fix: mistral adaptor 2024-11-05 19:32:51 +08:00
1808837298@qq.com
4e0c522cd0 fix: realtime计费
(cherry picked from commit fdfea8726c6d86d3844af1ac18d7b3df908f26a7)
2024-11-05 19:29:06 +08:00
1808837298@qq.com
f08f7ae940 fix: channel test
(cherry picked from commit 052bdab1c45b3a4ba5f079afc763f54e751b1cd7)
2024-11-05 19:28:58 +08:00
Xyfacai
be64408a25 fix(realtime): 修复ws 握手失败、计费问题
(cherry picked from commit 618dffc43fd5a5f4065944db87761f9ee18e44d3)
2024-11-05 19:28:46 +08:00
Xyfacai
d596699250 refactor: realtime log
(cherry picked from commit fd24dc467bfc360008b313220e607f0176ee7aa3)
2024-11-05 19:28:09 +08:00
Xyfacai
f0907bf60a fix: 部分情况缺少返回预扣
(cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e)
2024-11-05 19:28:08 +08:00
1808837298@qq.com
e5c05d77b7 feat: realtime pre consume
(cherry picked from commit 273d154e1640bae26b7caedddf1685e9ff21ab74)
2024-11-05 19:28:06 +08:00
1808837298@qq.com
24b3ed50d7 feat: realtime pre consume
(cherry picked from commit d87917f8f6eb9d2e144a9f840d6d91767ea2eb69)
2024-11-05 19:28:03 +08:00
1808837298@qq.com
8de79382f0 feat: azure realtime
(cherry picked from commit 75ff3d98f06103dc2df1f8817bd3fcbf433e0f20)
2024-11-05 19:27:55 +08:00
1808837298@qq.com
74f9006b40 feat: realtime
(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
2024-11-05 19:27:47 +08:00
1808837298@qq.com
33af069fae feat: realtime
(cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8)
2024-11-05 19:24:14 +08:00
1808837298@qq.com
e3c85572d4 Update dto
(cherry picked from commit 030187ff75c64c40017cda2fa98ef2b3c01f0bd5)
2024-11-05 19:23:56 +08:00
CalciumIon
4b48e490fa feat: 添加Mistral渠道 (close #546) 2024-11-05 17:11:33 +08:00
CalciumIon
3e2ae29ba0 fix: 修复聊天环境变量替换不完全 (close #542) 2024-11-05 16:02:10 +08:00
CalciumIon
fe0ed128c6 chore: update model ratio 2024-11-05 15:58:22 +08:00
Calcium-Ion
3785e9d754 Merge pull request #549 from HynoR/main
chore: 更新最新haiku模型倍率
2024-11-05 14:58:55 +08:00
HynoR
902a66b60f Sync Latest Claude Model 2024-11-05 10:17:11 +08:00
Calcium-Ion
aaf3f09eec Merge pull request #548 from utopeadia/main
ollama /api/embeddings is deprecated, use /api/embed.
2024-11-04 22:21:52 +08:00
HowieWood
e523555844 /api/embeddings is deprecated, use /api/embed.
/api/embeddings is deprecated, use /api/embed.
2024-11-04 22:03:41 +08:00
CalciumIon
139a104b26 feat: support gpt-4o-audio-preview 2024-11-04 15:27:12 +08:00
1808837298@qq.com
8b8abfadaf Merge remote-tracking branch 'origin/main' 2024-10-24 00:19:18 +08:00
1808837298@qq.com
65e65097b2 feat: aws claude tools 2024-10-24 00:19:08 +08:00
Calcium-Ion
62e321fe30 Merge pull request #533 from HynoR/main
chore: 修正chatgpt-4o-latest补全倍率
2024-10-24 00:17:48 +08:00
1808837298@qq.com
312ab44800 feat: update claude models 2024-10-24 00:17:23 +08:00
TAKO
a2678a256d Update model-ratio.go
fix wrong model ration about chatgpt-4o-latest
2024-10-17 12:50:14 +08:00
1808837298@qq.com
8b67664995 feat: 上游渠道为OpenAI渠道类型时,透传请求 (close #532) 2024-10-15 18:37:44 +08:00
1808837298@qq.com
ade6d0f56a fix: 修复Playground分组无用户分组 (close #529) 2024-10-14 16:22:38 +08:00
1808837298@qq.com
f599c65944 fix: 修复用户可选分组不能选择用户分组 (close #528) 2024-10-14 16:22:22 +08:00
1808837298@qq.com
40baa636e4 fix: 修复自定义聊天bug
(cherry picked from commit 8d41c17ccf19cb29100dbe506d3d42a6be822ff9)
2024-10-13 00:21:52 +08:00
1808837298@qq.com
d6359ec4ff feat: 完善自定义聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
89ddf83b44 feat: 弃用旧的聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
6a8a4bcf65 fix: playground group 2024-10-10 13:39:09 +08:00
1808837298@qq.com
e298f2e5a4 feat: playground token name 2024-10-10 13:34:29 +08:00
1808837298@qq.com
8cea6dff4a feat: support embedding encoding_format param 2024-10-10 13:23:12 +08:00
1808837298@qq.com
5035cd054a feat: update aws claude 2024-10-09 00:42:36 +08:00
1808837298@qq.com
02c0c6501e feat: update auto disable 2024-10-08 23:15:57 +08:00
1808837298@qq.com
f0b808a41d feat: update model ratio 2024-10-03 21:12:09 +08:00
1808837298@qq.com
31d84ee32f feat: update model ratio 2024-10-03 20:48:47 +08:00
1808837298@qq.com
9969ed2d7c feat: update model ratio 2024-10-03 20:47:54 +08:00
1808837298@qq.com
746311242b fix: playground气泡溢出 #511 2024-09-27 20:49:26 +08:00
1808837298@qq.com
04a68a85dd feat: 优化playground样式 2024-09-27 20:49:25 +08:00
1808837298@qq.com
f9ba10f180 fix: playground max_tokens #512 #511 2024-09-27 20:18:53 +08:00
Calcium-Ion
334a6f8280 Update README.md 2024-09-26 01:54:33 +08:00
1808837298@qq.com
0cf53ac5ff feat: Playground相关接口禁用AccessToken 2024-09-26 01:49:35 +08:00
Calcium-Ion
af02cdc58b Merge pull request #509 from Calcium-Ion/playground
feat: playground
2024-09-26 01:00:33 +08:00
1808837298@qq.com
9a4ca1e210 feat: playground 2024-09-26 00:59:09 +08:00
1808837298@qq.com
9fe1f35fd1 fix: 第三方登录注销 #500 2024-09-25 17:15:59 +08:00
1808837298@qq.com
972ac1ee0f fix: 第三方登录注销 #500 2024-09-25 17:13:28 +08:00
1808837298@qq.com
0f95502b04 feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
b58b1dc0ec feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
05d9aa61df feat: 不自动生成系统访问令牌 2024-09-25 16:31:25 +08:00
1808837298@qq.com
221894d972 fix: error user role 2024-09-24 17:49:57 +08:00
1808837298@qq.com
50eab6b4e4 chore: 更新令牌分组描述 2024-09-22 19:43:06 +08:00
1808837298@qq.com
ed972eef06 feat: pricing page support multi groups #487 2024-09-22 17:44:57 +08:00
CalciumIon
c6ff785a83 feat: 无可选分组时关闭令牌分组功能 #485 2024-09-19 03:01:33 +08:00
CalciumIon
2e734e0c37 chore: 令牌分组描述歧义 2024-09-19 02:52:25 +08:00
CalciumIon
af33f36c7b feat: update gemini flash completion ratio #479 2024-09-18 20:39:06 +08:00
CalciumIon
3aa86a8cd9 feat: update gemini completion ratio #479 2024-09-18 20:37:22 +08:00
CalciumIon
af7fecbfa7 fix: 使用令牌分组时 "/v1/models" 返回模型不正确 #481 2024-09-18 19:19:37 +08:00
CalciumIon
3fbdd502b6 fix: token group #477 2024-09-18 18:55:11 +08:00
CalciumIon
052bc2075b feat: 令牌分组 2024-09-18 05:19:49 +08:00
Calcium-Ion
5f3798053f Create FUNDING.yml 2024-09-18 01:41:31 +08:00
CalciumIon
e31022c676 Update logo 2024-09-18 01:25:00 +08:00
Calcium-Ion
fff7609f06 Merge pull request #439 from guoruqiang/main
改进了聊天页面,增加了初始令牌,方便用户注册后即可使用聊天功能。
2024-09-17 23:14:19 +08:00
CalciumIon
9032b5cfbf fix: 初始令牌 2024-09-17 23:07:16 +08:00
CalciumIon
131453dac8 Update README.md 2024-09-17 23:01:34 +08:00
CalciumIon
ed948c121a Merge branch 'main' into g-main
# Conflicts:
#	web/src/App.js
2024-09-17 22:50:59 +08:00
CalciumIon
a03cd15505 fix: '/v1/models' #474 2024-09-17 22:41:54 +08:00
CalciumIon
02f5137781 fix: '/v1/models' #474 2024-09-17 22:39:58 +08:00
CalciumIon
e6df0ed20c fix: '/vi/models' #474 2024-09-17 22:36:20 +08:00
CalciumIon
f505afdc10 feat: 添加令牌ip白名单功能 2024-09-17 20:49:51 +08:00
CalciumIon
feb1d76942 feat: 优化界面显示 2024-09-17 19:55:18 +08:00
CalciumIon
6263616cd9 Update README.md 2024-09-17 03:18:12 +08:00
GuoRuqiang
6bbf1d4843 Merge branch 'Calcium-Ion:main' into main 2024-09-14 19:00:03 +08:00
1808837298@qq.com
13c993d87e feat: format o1 model max tokens param 2024-09-14 16:11:38 +08:00
CalciumIon
cb73889353 feat: support o1 channel test 2024-09-13 03:17:04 +08:00
CalciumIon
804aad3f37 feat: support o1 channel test 2024-09-13 03:15:32 +08:00
CalciumIon
3af62a3efa feat: support OpenAI o1-preview and o1-mini 2024-09-13 01:22:27 +08:00
CalciumIon
be54369c12 chore: update footer 2024-09-12 18:43:01 +08:00
CalciumIon
0cbf8e07e7 feat: support ollama multi-text embedding 2024-09-12 18:29:45 +08:00
Calcium-Ion
1675679be9 Merge pull request #464 from Yan-Zero/main
fix: tool use in claude and add gemini mapping
2024-09-12 05:04:19 +08:00
Yan
0b5f2a7089 add gemini exp 2024-09-11 19:37:03 +08:00
Yan Tau
b5bb708072 Merge branch 'Calcium-Ion:main' into main 2024-09-11 19:29:50 +08:00
CalciumIon
2650ec9b59 feat: claude response return model name 2024-09-11 19:12:55 +08:00
CalciumIon
d168a685c1 fix: cohere SafetyMode 2024-09-11 19:12:32 +08:00
GuoRuqiang
a0d20896b3 Merge branch 'Calcium-Ion:main' into main 2024-09-08 15:56:54 +08:00
Calcium-Ion
5cab06d1ce Merge pull request #459 from HynoR/main
chore: 适配cohere的safety参数
2024-09-05 18:37:47 +08:00
CalciumIon
e3b3fdec48 feat: update chatgpt-4o token encoder 2024-09-05 18:35:34 +08:00
CalciumIon
5863aa8061 feat: remove lobe chat link #457 2024-09-05 18:34:04 +08:00
Yan
0ada2371b6 fix: tool use in claude 2024-09-05 00:53:00 +08:00
CalciumIon
8bc1e956cf fix: email 2024-09-04 19:44:29 +08:00
GuoRuqiang
a0673ef2b6 Merge branch 'Calcium-Ion:main' into main 2024-09-02 21:53:54 +08:00
HynoR
416f831a6c Merge remote-tracking branch 'origin/main' 2024-09-02 06:47:58 +07:00
HynoR
0b4317ce28 Update Cohere Safety Setting 2024-09-02 06:47:49 +07:00
Calcium-Ion
12e2481acb Merge pull request #451 from Nana7mi1/main
feat: support more zhipu models
2024-09-02 01:12:10 +08:00
Calcium-Ion
270709064d Merge pull request #455 from HynoR/feat/cohere-update
Feat: 更新Cohere新模型和定价
2024-09-02 01:11:55 +08:00
CalciumIon
0830ef3305 feat: support jina embedding 2024-09-02 01:11:19 +08:00
HynoR
722cc174b7 Cohere Update 2024-09-01 15:21:05 +07:00
Nanami
97c18d0c7f feat: support more zhipu models 2024-08-31 10:20:22 +08:00
GuoRuqiang
2223aeb022 Merge branch 'Calcium-Ion:main' into main 2024-08-29 19:42:03 +08:00
CalciumIon
4b1e83c42d feat: support siliconflow embedding #447 2024-08-29 00:19:30 +08:00
GuoRuqiang
ecf2f7f212 Merge branch 'Calcium-Ion:main' into main 2024-08-28 21:44:54 +08:00
CalciumIon
01fd8b53a6 feat: 检测vertex渠道部署地区是否填写 2024-08-28 18:47:27 +08:00
CalciumIon
e60f200192 feat: 支持vertex ai渠道多个部署地区 2024-08-28 18:43:40 +08:00
GuoRuqiang
033359e93c Merge branch 'Calcium-Ion:main' into main 2024-08-28 10:44:14 +08:00
CalciumIon
c41820541d Update go.mod 2024-08-27 20:30:46 +08:00
CalciumIon
228f0c5ee5 Update README.md 2024-08-27 20:25:55 +08:00
Calcium-Ion
8a5e074f14 Merge pull request #448 from Calcium-Ion/vertex
feat: support vertex ai
2024-08-27 20:21:01 +08:00
CalciumIon
ac4262c542 feat: support vertex ai #377 2024-08-27 20:19:51 +08:00
GuoRuqiang
1379d7f184 Merge pull request #2 from j471782517/main
增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。
2024-08-25 02:53:47 +08:00
Jin Weihan
716bf6f48a 增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。 2024-08-24 18:44:37 +00:00
GuoRuqiang
2422eb2820 Merge branch 'Calcium-Ion:main' into main 2024-08-25 01:55:23 +08:00
CalciumIon
46e03683ce fix: channel auto ban 2024-08-24 17:27:14 +08:00
CalciumIon
ff0985f06e fix: channel auto ban #443 2024-08-24 17:23:24 +08:00
CalciumIon
a8ac8a25d5 feat: format claude messages when first role is not user 2024-08-24 17:15:55 +08:00
Xyfacai
5b2082ba58 Merge branch 'main' of https://github.com/Calcium-Ion/new-api 2024-08-24 13:36:44 +08:00
Xyfacai
967ccabb56 fix: 修复 dall-e-2 请求报错 2024-08-24 13:36:41 +08:00
CalciumIon
144513f1d8 feat: rerank model mapping (close #444) 2024-08-23 23:21:37 +08:00
Calcium-Ion
e3087e9bea Merge pull request #445 from OswinWu/fix-outlook-ofb
fix: 多地区outlook邮箱和ofb邮箱Auth
2024-08-23 23:16:37 +08:00
OswinWu
484a8595e4 fix: 多地区outlook邮箱和ofb邮箱Auth 2024-08-23 17:16:09 +08:00
GuoRuqiang
c97e2875b4 增加注册自动生成初始令牌。 2024-08-18 15:12:59 +00:00
GuoRuqiang
64794630c8 修改提示时间。 2024-08-17 16:59:31 +00:00
GuoRuqiang
fc5055c766 update App.js 2024-08-17 16:20:41 +00:00
GuoRuqiang
27eb358497 重新修改了chat 2024-08-17 16:17:24 +00:00
GuoRuqiang
6810ee0a28 Update Chat
修改chat界面,配合nextChat等前端可以自动传入第一个已启用令牌,
2024-08-17 23:09:45 +08:00
CalciumIon
7c4d9d225e feat: support SiliconFlow (close #437, close #403) 2024-08-16 18:27:26 +08:00
CalciumIon
d0f76a5c61 feat: support gpt-4o-gizmo-* (close #436) 2024-08-16 17:25:03 +08:00
CalciumIon
a5ec11e463 fix: add email missing Message-ID 2024-08-16 16:16:38 +08:00
CalciumIon
b3d8e3e9ae fix: lobechat #430 2024-08-16 14:59:32 +08:00
CalciumIon
0c46d0c7af chore: remove useless code 2024-08-14 22:44:33 +08:00
CalciumIon
8cd8cc29bc fix: log page 'Cannot read properties of undefined (reading 'length')' 2024-08-14 22:43:57 +08:00
CalciumIon
748e34fd10 feat: update openai models list 2024-08-14 15:51:48 +08:00
CalciumIon
f9392ca904 feat: 避免暴露内部错误 2024-08-14 15:49:33 +08:00
CalciumIon
1988c41842 feat: update chatgpt-4o-latest model ratio 2024-08-14 15:47:08 +08:00
CalciumIon
6cb0eb4b39 feat: update claude tools calling 2024-08-13 17:54:24 +08:00
Calcium-Ion
59d06a5576 Merge pull request #427 from QuentinHsu/fix-log-pagination
fix log pagination
2024-08-13 17:50:12 +08:00
Calcium-Ion
1b900e3917 Merge pull request #426 from OswinWu/fix-log-page
Fix log page
2024-08-13 17:50:03 +08:00
Calcium-Ion
accbae3904 Merge pull request #432 from xixingya/feat-add-logdb
Feature: Support Log DB
2024-08-13 17:48:25 +08:00
liuzhifei
d82bd20354 support log db 2024-08-13 10:29:55 +08:00
liuzhifei
0c01f49bc5 add log db 2024-08-13 10:28:35 +08:00
QuentinHsu
9edb7c4ade fix: log pagination 2024-08-11 11:25:32 +08:00
Nothing.
228104e848 Merge branch 'Calcium-Ion:main' into fix-log-page 2024-08-11 11:22:34 +08:00
OswinWu
a2af637e7f fix: log分页问题 2024-08-11 11:21:34 +08:00
QuentinHsu
d6f6403fd3 chore: update @so1ve/prettier-config to version 3.1.0 2024-08-11 11:18:08 +08:00
CalciumIon
4b5303a77b feat: 区分额度不足和预扣费失败提示 2024-08-09 18:48:13 +08:00
CalciumIon
6eab0cc370 feat: 区分额度不足和预扣费失败提示 2024-08-09 18:34:51 +08:00
CalciumIon
9e45dbe964 fix: close #422 2024-08-09 16:14:05 +08:00
Calcium-Ion
e495354823 Merge pull request #425 from dalefengs/fix_group
fix: 渠道多分组查询 sqlite 查询兼容
2024-08-09 15:44:23 +08:00
FENG
9452be51b9 fix: sqlite group 查询兼容 2024-08-09 11:39:19 +08:00
Calcium-Ion
43076c2f33 Merge pull request #415 from dalefengs/fix_group
fix: 渠道多分组,优化分组 like 查询
2024-08-08 20:47:51 +08:00
CalciumIon
04f0084d97 fix: 修复mysql兼容问题 2024-08-08 20:45:41 +08:00
CalciumIon
2e3c266bd6 fix: response format 2024-08-07 15:43:01 +08:00
FENG
e614ca370a fix: optionList bug 2024-08-06 21:30:20 +08:00
FENG
c152b4de08 chore: indent recovery 2024-08-06 15:40:44 +08:00
FENG
190316f66e fix: 渠道多分组,优化分组 like 查询 2024-08-05 22:35:16 +08:00
147 changed files with 6944 additions and 1396 deletions

12
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -1,6 +1,13 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API # New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE] > [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
@@ -41,6 +48,7 @@
4. Telegram Bot 名称是bot username 去掉@后的字符串 4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md) 13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md) 14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
15. **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道。
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
@@ -54,10 +62,12 @@
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md) 9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify 10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置 ## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。 - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
@@ -65,7 +75,7 @@
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true` - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。 - `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 - `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署 ## 部署
### 部署要求 ### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机) - 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
@@ -114,24 +124,19 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
## Suno接口设置文档 ## Suno接口设置文档
[对接文档](Suno.md) [对接文档](Suno.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图 ## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式 夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="200">
## 相关项目 ## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目 - [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持 - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持

View File

@@ -112,6 +112,9 @@ var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
) )
@@ -123,6 +126,10 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var ( var (
FileUploadPermission = RoleGuestUser FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser FileDownloadPermission = RoleGuestUser
@@ -213,6 +220,9 @@ const (
ChannelTypeDify = 37 ChannelTypeDify = 37
ChannelTypeJina = 38 ChannelTypeJina = 38
ChannelCloudflare = 39 ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
@@ -259,4 +269,7 @@ var ChannelBaseURLs = []string{
"", //37 "", //37
"https://api.jina.ai", //38 "https://api.jina.ai", //38
"https://api.cloudflare.com", //39 "https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
} }

View File

@@ -3,6 +3,7 @@ package common
import ( import (
"errors" "errors"
"net/smtp" "net/smtp"
"strings"
) )
type outlookAuth struct { type outlookAuth struct {
@@ -30,3 +31,10 @@ func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
} }
return nil, nil return nil, nil
} }
func isOutlookServer(server string) bool {
// 兼容多地区的outlook邮箱和ofb邮箱
// 其实应该加一个Option来区分是否用LOGIN的方式登录
// 先临时兼容一下
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}

View File

@@ -9,17 +9,26 @@ import (
"time" "time"
) )
func generateMessageID() string {
domain := strings.Split(SMTPAccount, "@")[1]
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount SMTPFrom = SMTPAccount
} }
if SMTPServer == "" && SMTPAccount == "" {
return fmt.Errorf("SMTP 服务器未配置")
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Date: %s\r\n"+ "Date: %s\r\n"+
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content)) receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), generateMessageID(), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
@@ -62,7 +71,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil { if err != nil {
return err return err
} }
} else if strings.HasSuffix(SMTPAccount, "outlook.com") { } else if isOutlookServer(SMTPAccount) {
auth = LoginAuth(SMTPAccount, SMTPToken) auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else { } else {

View File

@@ -23,28 +23,36 @@ const (
var defaultModelRatio = map[string]float64{ var defaultModelRatio = map[string]float64{
//"midjourney": 50, //"midjourney": 50,
"gpt-4-gizmo-*": 15, "gpt-4-gizmo-*": 15,
"gpt-4-all": 15, "gpt-4o-gizmo-*": 2.5,
"gpt-4o-all": 15, "gpt-4-all": 15,
"gpt-4": 15, "gpt-4o-all": 15,
"gpt-4": 15,
//"gpt-4-0314": 15, //deprecated //"gpt-4-0314": 15, //deprecated
"gpt-4-0613": 15, "gpt-4-0613": 15,
"gpt-4-32k": 30, "gpt-4-32k": 30,
//"gpt-4-32k-0314": 30, //deprecated //"gpt-4-32k-0314": 30, //deprecated
"gpt-4-32k-0613": 30, "gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.01 / 1K tokens "chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens "gpt-4o": 1.25, // $0.01 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens "gpt-4o-audio-preview": 1.25, // $0.0015 / 1K tokens
"gpt-4o-mini": 0.075, "gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, "gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4o-2024-05-13": 2.5,
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4o-realtime-preview": 2.5,
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens "o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5,
"o1-mini-2024-09-12": 1.5,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated //"gpt-3.5-turbo-0301": 0.75, //deprecated
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
@@ -80,8 +88,10 @@ var defaultModelRatio = map[string]float64{
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, "claude-3-5-sonnet-20240620": 1.5,
"claude-3-5-sonnet-20241022": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
@@ -104,8 +114,10 @@ var defaultModelRatio = map[string]float64{
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1, "gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1, "gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1, "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1, "gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1, "gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1, "gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1, "gemini-ultra": 1,
@@ -117,6 +129,13 @@ var defaultModelRatio = map[string]float64{
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572, "glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
@@ -135,26 +154,28 @@ var defaultModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元 // https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格 // 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18, "yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864, "yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432, "yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB, "yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB, "yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB, "yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB, "yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB, "yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB, "yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB, "yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB, "yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB, "yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5, "command": 0.5,
"command-nightly": 0.5, "command-nightly": 0.5,
"command-light": 0.5, "command-light": 0.5,
"command-light-nightly": 0.5, "command-light-nightly": 0.5,
"command-r": 0.25, "command-r": 0.25,
"command-r-plus ": 1.5, "command-r-plus": 1.5,
"deepseek-chat": 0.07, "command-r-08-2024": 0.075,
"deepseek-coder": 0.07, "command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用 // Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
@@ -186,8 +207,8 @@ var defaultModelPrice = map[string]float64{
} }
var ( var (
modelPriceMap = make(map[string]float64) modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{} modelPriceMapMutex = sync.RWMutex{}
) )
var ( var (
modelRatioMap map[string]float64 = nil modelRatioMap map[string]float64 = nil
@@ -196,8 +217,9 @@ var (
var CompletionRatio map[string]float64 = nil var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"gpt-4-all": 2, "gpt-4o-gizmo-*": 3,
"gpt-4-all": 2,
} }
func GetModelPriceMap() map[string]float64 { func GetModelPriceMap() map[string]float64 {
@@ -231,6 +253,9 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
price, ok := modelPriceMap[name] price, ok := modelPriceMap[name]
if !ok { if !ok {
if printErr { if printErr {
@@ -311,6 +336,34 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
}
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "o1-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.Contains(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates // https://openai.com/blog/new-embedding-models-and-api-updates
@@ -322,30 +375,11 @@ func GetCompletionRatio(name string) float64 {
} }
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
return 4
}
return 3
}
return 2
}
if strings.Contains(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "mistral-") { if strings.HasPrefix(name, "mistral-") {
return 3 return 3
} }
if strings.HasPrefix(name, "gemini-") { if strings.HasPrefix(name, "gemini-") {
return 3 return 4
} }
if strings.HasPrefix(name, "command") { if strings.HasPrefix(name, "command") {
switch name { switch name {
@@ -353,6 +387,10 @@ func GetCompletionRatio(name string) float64 {
return 3 return 3
case "command-r-plus": case "command-r-plus":
return 5 return 5
case "command-r-08-2024":
return 4
case "command-r-plus-08-2024":
return 4
default: default:
return 2 return 2
} }
@@ -383,6 +421,34 @@ func GetCompletionRatio(name string) float64 {
return 1 return 1
} }
func GetAudioRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 20
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 10
}
return 2
}
//func GetAudioPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.06
// }
// return 0.06
//}
//
//func GetAudioCompletionPricePerMinute(name string) float64 {
// if strings.HasPrefix(name, "gpt-4o-realtime") {
// return 0.24
// }
// return 0.24
//}
func GetCompletionRatioMap() map[string]float64 { func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil { if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio CompletionRatio = defaultCompletionRatio

View File

@@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes) return string(bytes)
} }
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} { func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := json.Unmarshal([]byte(str), &m)
@@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
return m return m
} }
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int { func String2Int(str string) int {
num, err := strconv.Atoi(str) num, err := strconv.Atoi(str)
if err != nil { if err != nil {

46
common/user_groups.go Normal file
View File

@@ -0,0 +1,46 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@@ -1,10 +1,13 @@
package common package common
import ( import (
crand "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"html/template" "html/template"
"log" "log"
"math/big"
"math/rand" "math/rand"
"net" "net"
"os/exec" "os/exec"
@@ -128,6 +131,11 @@ func IntMax(a int, b int) int {
} }
} }
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string { func GetUUID() string {
code := uuid.New().String() code := uuid.New().String()
code = strings.Replace(code, "-", "", -1) code = strings.Replace(code, "-", "", -1)
@@ -137,24 +145,35 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() { func init() {
rand.Seed(time.Now().UnixNano()) rand.New(rand.NewSource(time.Now().UnixNano()))
} }
func GenerateKey() string { func GenerateRandomCharsKey(length int) (string, error) {
//rand.Seed(time.Now().UnixNano()) b := make([]byte, length)
key := make([]byte, 48) maxI := big.NewInt(int64(len(keyChars)))
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))] for i := range b {
} n, err := crand.Int(crand.Reader, maxI)
uuid_ := GetUUID() if err != nil {
for i := 0; i < 32; i++ { return "", err
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
} }
key[i+16] = c b[i] = keyChars[n.Int64()]
} }
return string(key)
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
//rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48)
} }
func GetRandomInt(max int) int { func GetRandomInt(max int) int {

35
constant/chat.go Normal file
View File

@@ -0,0 +1,35 @@
package constant
import (
"encoding/json"
"one-api/common"
)
var Chats = []map[string]string{
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},
{
"OpenCat": "opencat://team/join?domain={address}&token={key}",
},
}
func UpdateChatsByJsonString(jsonString string) error {
Chats = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &Chats)
}
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
common.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
}

View File

@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{ var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta", "gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta", "gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta", "gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta", "gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta", "gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta", "gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash": "v1beta", "gemini-1.5-flash-exp-0827": "v1beta",
"gemini-ultra": "v1beta", "gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
} }
func InitEnv() { func InitEnv() {
@@ -44,3 +46,6 @@ func InitEnv() {
} }
} }
} }
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@@ -20,6 +20,7 @@ import (
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
request := buildTestRequest() request := buildTestRequest(testModel)
request.Model = testModel
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
@@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil { if err != nil {
return err, nil return err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { var httpResp *http.Response
err := service.RelayErrorHandler(resp) if resp != nil {
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
}
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr return fmt.Errorf("%s", respErr.Error.Message), respErr
} }
if usage == nil { if usageA == nil {
return errors.New("usage is nil"), nil return errors.New("usage is nil"), nil
} }
usage := usageA.(*dto.Usage)
result := w.Result() result := w.Result()
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
@@ -141,17 +146,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return nil, nil return nil, nil
} }
func buildTestRequest() *dto.GeneralOpenAIRequest { func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later Model: "", // this will be set later
MaxTokens: 1, Stream: false,
Stream: false, }
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
Role: "user", Role: "user",
Content: content, Content: content,
} }
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
@@ -226,26 +236,22 @@ func testAllChannels(notify bool) error {
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
ban := false shouldBanChannel := false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
// request error disables the channel // request error disables the channel
if openaiWithStatusErr != nil { if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
} }
// parse *int to bool if milliseconds > disableThreshold {
if !channel.GetAutoBan() { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = false shouldBanChannel = true
} }
// disable channel // disable channel
if ban && isChannelEnabled { if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error()) service.DisableChannel(channel.Id, channel.Name, err.Error())
} }

View File

@@ -198,6 +198,28 @@ func AddChannel(c *gin.Context) {
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
@@ -297,6 +319,27 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} }
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -112,7 +112,9 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{ user := model.User{
GitHubId: githubUser.Login, GitHubId: githubUser.Login,
} }
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) { if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId() err := user.FillUserByGitHubId()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -121,6 +123,14 @@ func GitHubOAuth(c *gin.Context) {
}) })
return return
} }
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
@@ -17,3 +18,22 @@ func GetGroups(c *gin.Context) {
"data": groupNames, "data": groupNames,
}) })
} }
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

View File

@@ -1,18 +1,19 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"github.com/gin-gonic/gin"
) )
func GetAllLogs(c *gin.Context) { func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 { if p < 1 {
p = 0 p = 1
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@@ -24,7 +25,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel) logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -35,16 +36,20 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 { if p < 1 {
p = 0 p = 1
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@@ -58,7 +63,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize) logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -69,7 +74,12 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return return
} }

View File

@@ -63,6 +63,7 @@ func GetStatus(c *gin.Context) {
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "", "enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
}, },
}) })
return return

View File

@@ -137,31 +137,63 @@ func init() {
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]dto.OpenAIModels, 0) userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission() permission := getPermission()
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok { modelLimitEnable := c.GetBool("token_model_limit_enabled")
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else { } else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ tokenModelLimit = map[string]bool{}
Id: s, }
Object: "model", for allowModel, _ := range tokenModelLimit {
Created: 1626777600, if _, ok := openAIModelsMap[allowModel]; ok {
OwnedBy: "custom", userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
Permission: permission, } else {
Root: s, userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Parent: nil, Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
}) })
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -7,18 +7,11 @@ import (
) )
func GetPricing(c *gin.Context) { func GetPricing(c *gin.Context) {
userId := c.GetInt("id") pricing := model.GetPricing()
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"data": pricing, "data": pricing,
"group_ratio": groupRatio, "group_ratio": common.GroupRatio,
}) })
} }

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"log" "log"
"net/http" "net/http"
@@ -38,6 +39,67 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err return err
} }
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
default:
err = relay.TextHelper(c)
}
return err
}
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
@@ -82,6 +144,67 @@ func Relay(c *gin.Context) {
} }
} }
var upgrader = websocket.Upgrader{
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol则必须在此声明对应的 Protocol TODO add other protocol
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
@@ -89,6 +212,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
return relayHandler(c, relayMode) return relayHandler(c, relayMode)
} }
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func addUsedChannel(c *gin.Context, channelId int) { func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
@@ -121,6 +251,9 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
if openaiErr == nil { if openaiErr == nil {
return false return false
} }
if openaiErr.LocalError {
return false
}
if retryTimes <= 0 { if retryTimes <= 0 {
return false return false
} }
@@ -151,9 +284,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
// azure处理超时不重试 // azure处理超时不重试
return false return false
} }
if openaiErr.LocalError {
return false
}
if openaiErr.StatusCode/100 == 2 { if openaiErr.StatusCode/100 == 2 {
return false return false
} }

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"io" "io"
"net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"sort" "sort"
@@ -48,6 +49,13 @@ func TelegramBind(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId user.TelegramId = telegramId
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -123,10 +123,19 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: common.GenerateKey(), Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@@ -134,6 +143,8 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits, ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
@@ -221,6 +232,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@@ -7,10 +7,12 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/constant"
) )
type LoginRequest struct { type LoginRequest struct {
@@ -66,6 +68,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username) session.Set("username", user.Username)
session.Set("role", user.Role) session.Set("role", user.Role)
session.Set("status", user.Status) session.Set("status", user.Status)
session.Set("group", user.Group)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -157,8 +160,9 @@ func Register(c *gin.Context) {
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": "数据库错误,请稍后重试",
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@@ -186,6 +190,48 @@ func Register(c *gin.Context) {
}) })
return return
} }
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -276,7 +322,18 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = common.GetUUID() // get rand int 28-32
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -582,6 +639,7 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -629,8 +687,8 @@ func CreateUser(c *gin.Context) {
} }
type ManageRequest struct { type ManageRequest struct {
Username string `json:"username"` Id int `json:"id"`
Action string `json:"action"` Action string `json:"action"`
} }
// ManageUser Only admin user can do this // ManageUser Only admin user can do this
@@ -646,7 +704,7 @@ func ManageUser(c *gin.Context) {
return return
} }
user := model.User{ user := model.User{
Username: req.Username, Id: req.Id,
} }
// Fill attributes // Fill attributes
model.DB.Unscoped().Where(&user).First(&user) model.DB.Unscoped().Where(&user).First(&user)

View File

@@ -78,6 +78,13 @@ func WeChatAuth(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@@ -7,31 +7,35 @@ type ResponseFormat struct {
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` TopP float64 `json:"top_p,omitempty"`
Stop any `json:"stop,omitempty"` TopK int `json:"top_k,omitempty"`
N int `json:"n,omitempty"` Stop any `json:"stop,omitempty"`
Input any `json:"input,omitempty"` N int `json:"n,omitempty"`
Instruction string `json:"instruction,omitempty"` Input any `json:"input,omitempty"`
Size string `json:"size,omitempty"` Instruction string `json:"instruction,omitempty"`
Functions any `json:"functions,omitempty"` Size string `json:"size,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` Functions any `json:"functions,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed float64 `json:"seed,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` Seed float64 `json:"seed,omitempty"`
User string `json:"user,omitempty"` Tools []ToolCall `json:"tools,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` User string `json:"user,omitempty"`
Dimensions int `json:"dimensions,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
} }
type OpenAITools struct { type OpenAITools struct {
@@ -81,9 +85,10 @@ type Message struct {
} }
type MediaMessage struct { type MediaMessage struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text"` Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"` ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
} }
type MessageImageUrl struct { type MessageImageUrl struct {
@@ -91,9 +96,15 @@ type MessageImageUrl struct {
Detail string `json:"detail"` Detail string `json:"detail"`
} }
type MessageInputAudio struct {
Data string `json:"data"` //base64
Format string `json:"format"`
}
const ( const (
ContentTypeText = "text" ContentTypeText = "text"
ContentTypeImageURL = "image_url" ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio"
) )
func (m Message) StringContent() string { func (m Message) StringContent() string {
@@ -166,11 +177,19 @@ func (m Message) ParseContent() []MediaMessage {
}, },
}) })
} }
case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: subObj["data"].(string),
Format: subObj["format"].(string),
},
})
}
} }
} }
return contentList return contentList
} }
return nil return nil
} }

View File

@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct { type OpenAITextResponse struct {
Id string `json:"id"` Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`

6
dto/playground.go Normal file
View File

@@ -0,0 +1,6 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

97
dto/realtime.go Normal file
View File

@@ -0,0 +1,97 @@
package dto
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
RealtimeEventTypeConversationCreate = "conversation.item.create"
RealtimeEventTypeResponseCreate = "response.create"
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
)
const (
RealtimeEventTypeResponseDone = "response.done"
RealtimeEventTypeSessionUpdated = "session.updated"
RealtimeEventTypeSessionCreated = "session.created"
RealtimeEventResponseAudioDelta = "response.audio.delta"
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
RealtimeEventConversationItemCreated = "conversation.item.created"
)
type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"`
Response *RealtimeResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
}
type RealtimeResponse struct {
Usage *RealtimeUsage `json:"usage"`
}
type RealtimeUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails InputTokenDetails `json:"input_token_details"`
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type OutputTokenDetails struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
}
type RealtimeSession struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`
Voice string `json:"voice"`
InputAudioFormat string `json:"input_audio_format"`
OutputAudioFormat string `json:"output_audio_format"`
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
TurnDetection interface{} `json:"turn_detection"`
Tools []RealTimeTool `json:"tools"`
ToolChoice string `json:"tool_choice"`
Temperature float64 `json:"temperature"`
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
}
type InputAudioTranscription struct {
Model string `json:"model"`
}
type RealTimeTool struct {
Type string `json:"type"`
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
type RealtimeItem struct {
Id string `json:"id"`
Type string `json:"type"`
Status string `json:"status"`
Role string `json:"role"`
Content []RealtimeContent `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
CallId string `json:"call_id,omitempty"`
}
type RealtimeContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
Transcript string `json:"transcript,omitempty"`
}

View File

@@ -1,14 +1,17 @@
package dto package dto
type RerankRequest struct { type RerankRequest struct {
Documents []any `json:"documents"` Documents []any `json:"documents"`
Query string `json:"query"` Query string `json:"query"`
Model string `json:"model"` Model string `json:"model"`
TopN int `json:"top_n"` TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
} }
type RerankResponseDocument struct { type RerankResponseDocument struct {
Document any `json:"document"` Document any `json:"document,omitempty"`
Index int `json:"index"` Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"` RelevanceScore float64 `json:"relevance_score"`
} }

22
go.mod
View File

@@ -1,7 +1,9 @@
module one-api module one-api
// +heroku goVersion go1.18 // +heroku goVersion go1.18
go 1.18 go 1.21
toolchain go1.22.4
require ( require (
github.com/Calcium-Ion/go-epay v0.0.2 github.com/Calcium-Ion/go-epay v0.0.2
@@ -9,6 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
@@ -24,7 +27,7 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0 github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.26.0
golang.org/x/image v0.15.0 golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
@@ -38,9 +41,8 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect
@@ -51,6 +53,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect
@@ -69,6 +72,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -76,10 +80,10 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.18.0 // indirect golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.17.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

40
go.sum
View File

@@ -23,8 +23,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -37,6 +37,7 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
@@ -57,6 +58,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -81,7 +83,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -142,8 +145,11 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
@@ -172,7 +178,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -191,18 +198,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -214,26 +221,27 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

View File

@@ -42,6 +42,11 @@ func main() {
if err != nil { if err != nil {
common.FatalLog("failed to initialize database: " + err.Error()) common.FatalLog("failed to initialize database: " + err.Error())
} }
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {

View File

@@ -10,6 +10,17 @@ import (
"strings" "strings"
) )
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) { func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c) session := sessions.Default(c)
username := session.Get("username") username := session.Get("username")
@@ -30,6 +41,14 @@ func authHelper(c *gin.Context, minRole int) {
} }
user := model.ValidateAccessToken(accessToken) user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" { if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid // Token is valid
username = user.Username username = user.Username
role = user.Role role = user.Role
@@ -91,9 +110,19 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username) c.Set("username", username)
c.Set("role", role) c.Set("role", role)
c.Set("id", id) c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next() c.Next()
} }
@@ -126,8 +155,27 @@ func RootAuth() func(c *gin.Context) {
} }
} }
func WssAuth(c *gin.Context) {
}
func TokenAuth() func(c *gin.Context) { func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
// 先检测是否为ws
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
// read sk from Sec-WebSocket-Protocol
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
parts := strings.Split(key, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "openai-insecure-api-key") {
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
break
}
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
key := c.Request.Header.Get("Authorization") key := c.Request.Header.Get("Authorization")
parts := make([]string, 0) parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(key, "Bearer ")
@@ -175,6 +223,8 @@ func TokenAuth() func(c *gin.Context) {
} else { } else {
c.Set("token_model_limit_enabled", false) c.Set("token_model_limit_enabled", false)
} }
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1]) c.Set("specific_channel_id", parts[1])

View File

@@ -22,6 +22,14 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
@@ -31,6 +39,20 @@ func Distribute() func(c *gin.Context) {
return return
} }
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup) c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
@@ -148,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error()) return nil, false, errors.New("无效的请求, " + err.Error())
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
modelRequest.Model = c.Query("model")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable" modelRequest.Model = "text-moderation-stable"
@@ -199,6 +225,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei: case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeGemini: case common.ChannelTypeGemini:

View File

@@ -36,6 +36,12 @@ func GetEnabledModels() []string {
return models return models
} }
func GetAllEnableAbilities() []Ability {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) { func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"

View File

@@ -270,6 +270,9 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} }
if strings.HasPrefix(model, "gpt-4o-gizmo") {
model = "gpt-4o-gizmo-*"
}
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled { if !common.MemoryCacheEnabled {

View File

@@ -106,16 +106,23 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
// 构造WHERE子句 // 构造WHERE子句
var whereClause string var whereClause string
var args []interface{} var args []interface{}
if group != "" { if group != "" && group != "null" {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?" var groupCondition string
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%") if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
} }
// 执行查询 // 执行查询
err := baseQuery.Where(whereClause, args...).Find(&channels).Error err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,11 +3,12 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
@@ -38,7 +39,7 @@ const (
) )
func GetLogByKey(key string) (logs []*Log, err error) { func GetLogByKey(key string) (logs []*Log, err error) {
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
return logs, err return logs, err
} }
@@ -54,7 +55,7 @@ func RecordLog(userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.SysError("failed to record log: " + err.Error())
} }
@@ -84,7 +85,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
IsStream: isStream, IsStream: isStream,
Other: otherStr, Other: otherStr,
} }
err := DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
@@ -95,12 +96,12 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB tx = LOG_DB
} else { } else {
tx = DB.Where("type = ?", logType) tx = LOG_DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name like ?", modelName)
@@ -120,16 +121,23 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err if err != nil {
return nil, 0, err
}
return logs, total, err
} }
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId) tx = LOG_DB.Where("user_id = ?", userId)
} else { } else {
tx = DB.Where("user_id = ? and type = ?", userId, logType) tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name like ?", modelName)
@@ -143,6 +151,10 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
for i := range logs { for i := range logs {
var otherMap map[string]interface{} var otherMap map[string]interface{}
@@ -153,16 +165,16 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
logs[i].Other = common.MapToJsonStr(otherMap) logs[i].Other = common.MapToJsonStr(otherMap)
} }
return logs, err return logs, total, err
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }
@@ -173,10 +185,10 @@ type Stat struct {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota") tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询 // 为rpm和tpm创建单独的查询
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@@ -193,8 +205,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name like ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName) rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
@@ -215,7 +227,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -236,6 +248,6 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func DeleteOldLog(targetTimestamp int64) (int64, error) { func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -15,6 +15,8 @@ import (
var DB *gorm.DB var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error { func createRootAccountIfNeed() error {
var user User var user User
//if user.Status != common.UserStatusEnabled { //if user.Status != common.UserStatusEnabled {
@@ -30,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: nil,
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -38,9 +40,9 @@ func createRootAccountIfNeed() error {
return nil return nil
} }
func chooseDB() (*gorm.DB, error) { func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" { dsn := os.Getenv(envName)
dsn := os.Getenv("SQL_DSN") if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
@@ -52,6 +54,13 @@ func chooseDB() (*gorm.DB, error) {
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL // Use MySQL
common.SysLog("using MySQL as database") common.SysLog("using MySQL as database")
// check parseTime // check parseTime
@@ -76,7 +85,7 @@ func chooseDB() (*gorm.DB, error) {
} }
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB() db, err := chooseDB("SQL_DSN")
if err == nil { if err == nil {
if common.DebugEnabled { if common.DebugEnabled {
db = db.Debug() db = db.Debug()
@@ -100,52 +109,7 @@ func InitDB() (err error) {
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//} //}
common.SysLog("database migration started") common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{}) err = migrateDB()
if err != nil {
return err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return err
}
err = db.AutoMigrate(&User{})
if err != nil {
return err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = db.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err return err
} else { } else {
common.FatalLog(err) common.FatalLog(err)
@@ -153,8 +117,103 @@ func InitDB() (err error) {
return err return err
} }
func CloseDB() error { func InitLogDB() (err error) {
sqlDB, err := DB.DB() if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
LOG_DB = db
sqlDB, err := LOG_DB.DB()
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateLOGDB()
return err
} else {
common.FatalLog(err)
}
return err
}
func migrateDB() error {
err := DB.AutoMigrate(&Channel{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
@@ -162,6 +221,16 @@ func CloseDB() error {
return err return err
} }
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}
var ( var (
lastPingTime time.Time lastPingTime time.Time
pingMutex sync.Mutex pingMutex sync.Mutex

View File

@@ -69,6 +69,7 @@ func InitOptionMap() {
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64) common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp) common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
@@ -86,6 +87,7 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
@@ -247,6 +249,8 @@ func updateOptionMap(key string, value string) (err error) {
constant.WorkerValidKey = value constant.WorkerValidKey = value
case "PayAddress": case "PayAddress":
constant.PayAddress = value constant.PayAddress = value
case "Chats":
err = constant.UpdateChatsByJsonString(value)
case "CustomCallbackAddress": case "CustomCallbackAddress":
constant.CustomCallbackAddress = value constant.CustomCallbackAddress = value
case "EpayId": case "EpayId":
@@ -303,6 +307,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value) err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice": case "ModelPrice":

View File

@@ -7,14 +7,13 @@ import (
) )
type Pricing struct { type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"` QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"` ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"` ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"` OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"` CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"` EnableGroup []string `json:"enable_groups,omitempty"`
} }
var ( var (
@@ -23,40 +22,47 @@ var (
updatePricingLock sync.Mutex updatePricingLock sync.Mutex
) )
func GetPricing(group string) []Pricing { func GetPricing() []Pricing {
updatePricingLock.Lock() updatePricingLock.Lock()
defer updatePricingLock.Unlock() defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing() updatePricing()
} }
if group != "" { //if group != "" {
userPricingMap := make([]Pricing, 0) // userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group) // models := GetGroupModels(group)
for _, pricing := range pricingMap { // for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) { // if !common.StringsContains(models, pricing.ModelName) {
pricing.Available = false // pricing.Available = false
} // }
userPricingMap = append(userPricingMap, pricing) // userPricingMap = append(userPricingMap, pricing)
} // }
return userPricingMap // return userPricingMap
} //}
return pricingMap return pricingMap
} }
func updatePricing() { func updatePricing() {
//modelRatios := common.GetModelRatios() //modelRatios := common.GetModelRatios()
enabledModels := GetEnabledModels() enableAbilities := GetAllEnableAbilities()
allModels := make(map[string]int) modelGroupsMap := make(map[string][]string)
for i, model := range enabledModels { for _, ability := range enableAbilities {
allModels[model] = i groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
} }
pricingMap = make([]Pricing, 0) pricingMap = make([]Pricing, 0)
for model, _ := range allModels { for model, groups := range modelGroupsMap {
pricing := Pricing{ pricing := Pricing{
Available: true, ModelName: model,
ModelName: model, EnableGroup: groups,
} }
modelPrice, findPrice := common.GetModelPrice(model, false) modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice { if findPrice {

View File

@@ -6,6 +6,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
relaycommon "one-api/relay/common"
"strconv" "strconv"
"strings" "strings"
) )
@@ -23,10 +24,34 @@ type Token struct {
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"` ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
var tokens []*Token var tokens []*Token
var err error var err error
@@ -130,7 +155,8 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err return err
} }
@@ -232,51 +258,56 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err return err
} }
func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) { func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
if quota < 0 { if quota < 0 {
return 0, errors.New("quota 不能为负数!") return 0, errors.New("quota 不能为负数!")
} }
token, err := GetTokenById(tokenId) if !relayInfo.IsPlayground {
if err != nil { token, err := GetTokenById(relayInfo.TokenId)
return 0, err if err != nil {
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
} }
if !token.UnlimitedQuota && token.RemainQuota < quota { userQuota, err = GetUserQuota(relayInfo.UserId)
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if userQuota < quota { if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
err = DecreaseTokenQuota(tokenId, quota) if !relayInfo.IsPlayground {
if err != nil { err = DecreaseTokenQuota(relayInfo.TokenId, quota)
return 0, err if err != nil {
return 0, err
}
} }
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(relayInfo.UserId, quota)
return userQuota - quota, err return userQuota - quota, err
} }
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(relayInfo.UserId, quota)
} else { } else {
err = IncreaseUserQuota(token.UserId, -quota) err = IncreaseUserQuota(relayInfo.UserId, -quota)
} }
if err != nil { if err != nil {
return err return err
} }
if quota > 0 { if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(tokenId, quota) if quota > 0 {
} else { err = DecreaseTokenQuota(relayInfo.TokenId, quota)
err = IncreaseTokenQuota(tokenId, -quota) } else {
} err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
if err != nil { }
return err if err != nil {
return err
}
} }
if sendEmail { if sendEmail {
@@ -285,7 +316,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(token.UserId) email, err := GetUserEmail(relayInfo.UserId)
if err != nil { if err != nil {
common.SysError("failed to fetch user email: " + err.Error()) common.SysError("failed to fetch user email: " + err.Error())
} }

View File

@@ -25,7 +25,7 @@ type User struct {
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@@ -38,6 +38,17 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (user *User) GetAccessToken() string {
if user.AccessToken == nil {
return ""
}
return *user.AccessToken
}
func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) { func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User var user User
@@ -201,7 +212,7 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID() //user.SetAccessToken(common.GetUUID())
user.AffCode = common.GetRandomString(4) user.AffCode = common.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
@@ -295,11 +306,12 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values, // that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions // it wont be used to build query conditions
password := user.Password password := user.Password
if user.Username == "" || password == "" { username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
// find buy username or email // find buy username or email
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user) DB.Where("username = ? OR email = ?", username, username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@@ -339,14 +351,6 @@ func (user *User) FillUserByWeChatId() error {
return nil return nil
} }
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
}
DB.Where(User{Username: user.Username}).First(user)
return nil
}
func (user *User) FillUserByTelegramId() error { func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" { if user.TelegramId == "" {
return errors.New("Telegram id 为空!") return errors.New("Telegram id 为空!")
@@ -359,23 +363,19 @@ func (user *User) FillUserByTelegramId() error {
} }
func IsEmailAlreadyTaken(email string) bool { func IsEmailAlreadyTaken(email string) bool {
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
} }
func IsWeChatIdAlreadyTaken(wechatId string) bool { func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
} }
func IsGitHubIdAlreadyTaken(githubId string) bool { func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }
func IsTelegramIdAlreadyTaken(telegramId string) bool { func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
} }
func ResetUserPasswordByEmail(email string, password string) error { func ResetUserPasswordByEmail(email string, password string) error {

View File

@@ -12,13 +12,13 @@ type Adaptor interface {
// Init IsStream bool // Init IsStream bool
Init(info *relaycommon.RelayInfo) Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
} }

View File

@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream { if info.IsStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Set("X-DashScope-SSE", "enable")
} }
if c.GetString("plugin") != "" { if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Set("X-DashScope-Plugin", c.GetString("plugin"))
} }
return nil return nil
} }
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info) err, usage = aliImageHandler(c, resp, info)

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
"one-api/relay/common" "one-api/relay/common"
@@ -11,14 +12,16 @@ import (
"one-api/service" "one-api/service"
) )
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data // multipart/form-data
} else if info.RelayMode == constant.RelayModeRealtime {
// websocket
} else { } else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" { if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream") req.Set("Accept", "text/event-stream")
} }
} }
} }
@@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
} }
err = a.SetupRequestHeader(c, req, info) err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
@@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
// set form data // set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, req, info) err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil { if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err) return nil, fmt.Errorf("setup request header failed: %w", err)
} }
@@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
return resp, nil return resp, nil
} }
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
targetHeader := http.Header{}
err = a.SetupRequestHeader(c, &targetHeader, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
if err != nil {
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
}
// send request body
//all, err := io.ReadAll(requestBody)
//err = service.WssString(c, targetConn, string(all))
return targetConn, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {

View File

@@ -8,7 +8,6 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings"
) )
const ( const (
@@ -31,18 +30,14 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil return "", nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return nil return nil
} }
@@ -53,11 +48,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest var claudeReq *claude.ClaudeRequest
var err error var err error
if a.RequestMode == RequestModeCompletion { claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
} else {
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
}
c.Set("request_model", request.Model) c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq) c.Set("converted_request", claudeReq)
return claudeReq, err return claudeReq, err
@@ -67,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil return nil, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode) err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else { } else {

View File

@@ -1,13 +1,14 @@
package aws package aws
var awsModelIDMap = map[string]string{ var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1", "claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
} }
var ChannelName = "aws" var ChannelName = "aws"

View File

@@ -1,6 +1,8 @@
package aws package aws
import "one-api/relay/channel/claude" import (
"one-api/relay/channel/claude"
)
type AwsClaudeRequest struct { type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31" // AnthropicVersion should be "bedrock-2023-05-31"
@@ -12,4 +14,6 @@ type AwsClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
} }

View File

@@ -53,7 +53,7 @@ func awsModelID(requestModel string) (string, error) {
return awsModelID, nil return awsModelID, nil
} }
return "", errors.Errorf("model %s not found", requestModel) return requestModel, nil
} }
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {

View File

@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fullRequestURL, nil return fullRequestURL, nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = baiduStreamHandler(c, resp) err, usage = baiduStreamHandler(c, resp)
} else { } else {

View File

@@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey) req.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version") anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" { if anthropicVersion == "" {
anthropicVersion = "2023-06-01" anthropicVersion = "2023-06-01"
} }
req.Header.Set("anthropic-version", anthropicVersion) req.Set("anthropic-version", anthropicVersion)
return nil return nil
} }
@@ -73,15 +73,15 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode) err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else { } else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
} }
return return
} }

View File

@@ -8,7 +8,9 @@ var ModelList = []string{
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
} }
var ChannelName = "claude" var ChannelName = "claude"

View File

@@ -31,9 +31,9 @@ type ClaudeMessage struct {
} }
type Tool struct { type Tool struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"` InputSchema map[string]interface{} `json:"input_schema"`
} }
type InputSchema struct { type InputSchema struct {

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@@ -63,15 +64,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok { if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{ claudeTool := Tool{
Name: tool.Function.Name, Name: tool.Function.Name,
Description: tool.Function.Description, Description: tool.Function.Description,
InputSchema: InputSchema{ }
Type: params["type"].(string), claudeTool.InputSchema = make(map[string]interface{})
Properties: params["properties"], claudeTool.InputSchema["type"] = params["type"].(string)
Required: params["required"], claudeTool.InputSchema["properties"] = params["properties"]
}, claudeTool.InputSchema["required"] = params["required"]
}) for s, a := range params {
if s == "type" || s == "properties" || s == "required" {
continue
}
claudeTool.InputSchema[s] = a
}
claudeTools = append(claudeTools, claudeTool)
} }
} }
@@ -102,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
} }
formatMessages := make([]dto.Message, 0) formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message lastMessage := dto.Message{
Role: "tool",
}
for i, message := range textRequest.Messages { for i, message := range textRequest.Messages {
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" { if message.Role == "" {
textRequest.Messages[i].Role = "user" textRequest.Messages[i].Role = "user"
} }
@@ -116,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,
} }
if lastMessage != nil && lastMessage.Role == message.Role { if message.Role == "tool" {
fmtMessage.ToolCallId = message.ToolCallId
}
if message.Role == "assistant" && message.ToolCalls != nil {
fmtMessage.ToolCalls = message.ToolCalls
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() { if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content fmtMessage.Content = content
@@ -129,10 +139,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
fmtMessage.Content = content fmtMessage.Content = content
} }
formatMessages = append(formatMessages, fmtMessage) formatMessages = append(formatMessages, fmtMessage)
lastMessage = &textRequest.Messages[i] lastMessage = fmtMessage
} }
claudeMessages := make([]ClaudeMessage, 0) claudeMessages := make([]ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages { for _, message := range formatMessages {
if message.Role == "system" { if message.Role == "system" {
if message.IsStringContent() { if message.IsStringContent() {
@@ -148,10 +159,54 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeRequest.System = content claudeRequest.System = content
} }
} else { } else {
if isFirstMessage {
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
{
Type: "text",
Text: "...",
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{ claudeMessage := ClaudeMessage{
Role: message.Role, Role: message.Role,
} }
if message.IsStringContent() { if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
{
Type: "text",
Text: content,
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent() claudeMessage.Content = message.StringContent()
} else { } else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0) claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -184,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
} }
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
Input: inputObj,
})
}
}
claudeMessage.Content = claudeMediaMessages claudeMessage.Content = claudeMediaMessages
} }
claudeMessages = append(claudeMessages, claudeMessage) claudeMessages = append(claudeMessages, claudeMessage)
@@ -318,12 +395,13 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.ToolCalls = tools choice.Message.ToolCalls = tools
} }
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice) choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse
} }
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage var usage *dto.Usage
usage = &dto.Usage{} usage = &dto.Usage{}
@@ -405,7 +483,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, usage return nil, usage
} }
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -431,15 +509,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
}, nil }, nil
} }
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model) completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
} }
usage := dto.Usage{} usage := dto.Usage{}
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = completionTokens usage.CompletionTokens = completionTokens
usage.TotalTokens = promptTokens + completionTokens usage.TotalTokens = info.PromptTokens + completionTokens
} else { } else {
usage.PromptTokens = claudeResponse.Usage.InputTokens usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens

View File

@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
} }
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
fallthrough fallthrough

View File

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage

View File

@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Cohere(*request), nil return requestOpenAI2Cohere(*request), nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return requestConvertRerank2Cohere(request), nil return requestConvertRerank2Cohere(request), nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info) err, usage = cohereRerankHandler(c, resp, info)
} else { } else {

View File

@@ -1,7 +1,10 @@
package cohere package cohere
var ModelList = []string{ var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly", "command-r", "command-r-plus",
"command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
"command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0", "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
} }

View File

@@ -8,6 +8,7 @@ type CohereRequest struct {
Message string `json:"message"` Message string `json:"message"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
} }
type ChatHistory struct { type ChatHistory struct {

View File

@@ -23,6 +23,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
Stream: textRequest.Stream, Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(), MaxTokens: textRequest.GetMaxTokens(),
} }
if common.CohereSafetySetting != "NONE" {
cohereReq.SafetyMode = common.CohereSafetySetting
}
if cohereReq.MaxTokens == 0 { if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000 cohereReq.MaxTokens = 4000
} }
@@ -44,6 +47,7 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
}) })
} }
} }
return &cohereReq return &cohereReq
} }

View File

@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = difyStreamHandler(c, resp, info) err, usage = difyStreamHandler(c, resp, info)
} else { } else {

View File

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
} }
if usage.TotalTokens == 0 { if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
return nil, usage return nil, usage

View File

@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@@ -64,15 +64,15 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = geminiChatStreamHandler(c, resp, info) err, usage = GeminiChatStreamHandler(c, resp, info)
} else { } else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = GeminiChatHandler(c, resp)
} }
return return
} }

View File

@@ -6,7 +6,7 @@ const (
var ModelList = []string{ var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra", "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
} }
var ChannelName = "google gemini" var ChannelName = "google gemini"

View File

@@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response return &response
} }
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := "" responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
@@ -279,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
return nil, usage return nil, usage
} }
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} }
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil return nil
} }
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
@@ -55,9 +55,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp) err, usage = jinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
} }
return return
} }

View File

@@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage return nil, &jinaResp.Usage
} }
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.OpenAIEmbeddingResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -0,0 +1,72 @@
package mistral
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
mistralReq := requestOpenAI2Mistral(*request)
//common.LogJson(c, "body", mistralReq)
return mistralReq, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,12 @@
package mistral
var ModelList = []string{
"open-mistral-7b",
"open-mixtral-8x7b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
"mistral-embed",
}
var ChannelName = "mistral"

View File

@@ -0,0 +1,40 @@
package mistral
import (
"encoding/json"
"one-api/dto"
)
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
}
messageRaw, _ := json.Marshal(mediaMessages)
message.Content = messageRaw
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
}

View File

@@ -31,13 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode { switch info.RelayMode {
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embeddings", nil return info.BaseUrl + "/api/embed", nil
default: default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
return nil return nil
} }
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@@ -3,25 +3,39 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"` Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"` Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
} }
type OllamaEmbeddingRequest struct { type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"` Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
} }
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"` Embedding []float64 `json:"embedding,omitempty"`
} }

View File

@@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/service" "one-api/service"
"strings"
) )
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "), Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{ data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding, Embedding: ollamaEmbeddingResponse.Embedding,

View File

@@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRealtime {
// trim https
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = strings.TrimPrefix(baseUrl, "http://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
}
switch info.ChannelType { switch info.ChannelType {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -40,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax: case common.ChannelTypeMiniMax:
return minimax.GetRequestURL(info) return minimax.GetRequestURL(info)
@@ -54,16 +63,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
req.Header.Set("api-key", info.ApiKey) header.Set("api-key", info.ApiKey)
return nil return nil
} }
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
req.Header.Set("OpenAI-Organization", info.Organization) header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
"realtime",
"openai-insecure-api-key." + info.ApiKey,
"openai-beta.realtime-v1",
}
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
header.Set("Authorization", "Bearer "+info.ApiKey)
}
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
} }
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
//if info.ChannelType == common.ChannelTypeOpenRouter { //if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API") // req.Header.Set("X-Title", "One API")
@@ -78,6 +105,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if info.ChannelType != common.ChannelTypeOpenAI { if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil request.StreamOptions = nil
} }
if strings.HasPrefix(request.Model, "o1-") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
}
return request, nil return request, nil
} }
@@ -125,16 +158,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody) return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else { } else {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case constant.RelayModeAudioSpeech: case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info) err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation: case constant.RelayModeAudioTranslation:

View File

@@ -8,8 +8,13 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-curie-001", "text-babbage-001", "text-ada-001",
"text-moderation-latest", "text-moderation-stable", "text-moderation-latest", "text-moderation-stable",

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -231,7 +232,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{
@@ -324,7 +325,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage
} }
@@ -373,3 +374,210 @@ func getTextFromJSON(body []byte) (string, error) {
} }
return whisperResponse.Text, nil return whisperResponse.Text, nil
} }
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = service.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = service.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}

View File

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-goog-api-key", info.ApiKey) req.Set("x-goog-api-key", info.ApiKey)
return nil return nil
} }
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = palmStreamHandler(c, resp) err, responseText = palmStreamHandler(c, resp)

View File

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil }, nil
} }
fullTextResponse := responsePaLM2OpenAI(&palmResponse) fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{ usage := dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,

View File

@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Authorization", "Bearer "+info.ApiKey)
return nil return nil
} }
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@@ -0,0 +1,83 @@
package siliconflow
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,51 @@
package siliconflow
var ModelList = []string{
"THUDM/glm-4-9b-chat",
//"stabilityai/stable-diffusion-xl-base-1.0",
//"TencentARC/PhotoMaker",
"InstantX/InstantID",
//"stabilityai/stable-diffusion-2-1",
//"stabilityai/sd-turbo",
//"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
//"stabilityai/stable-diffusion-3-medium",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
"Qwen/Qwen2-Math-72B-Instruct",
"netease-youdao/bce-reranker-base_v1",
"BAAI/bge-reranker-v2-m3",
}
var ChannelName = "siliconflow"

View File

@@ -0,0 +1,17 @@
package siliconflow
import "one-api/dto"
type SFTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type SFMeta struct {
Tokens SFTokens `json:"tokens"`
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -0,0 +1,44 @@
package siliconflow
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
}
rerankResp := &dto.RerankResponse{
Results: siliconflowResp.Results,
Usage: *usage,
}
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil return fmt.Sprintf("%s/", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign) req.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", a.Action) req.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version) req.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil return nil
} }
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = tencentStreamHandler(c, resp) err, responseText = tencentStreamHandler(c, resp)

View File

@@ -0,0 +1,184 @@
package vertex
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
)
const (
RequestModeClaude = 1
RequestModeGemini = 2
RequestModeLlama = 3
)
var claudeModelMap = map[string]string{
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
}
const anthropicVersion = "vertex-2023-10-16"
type Adaptor struct {
RequestMode int
AccountCredentials Credentials
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &Credentials{}
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
info.UpstreamModelName = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Set("Authorization", "Bearer "+accessToken)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
vertexClaudeReq := &VertexAIClaudeRequest{
AnthropicVersion: anthropicVersion,
}
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request")
}
c.Set("request_model", request.Model)
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest := gemini.CovertGemini2OpenAI(*request)
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeLlama {
return request, nil
}
return nil, errors.New("unsupported request mode")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
case RequestModeLlama:
err, usage = openai.OaiStreamHandler(c, resp, info)
}
} else {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp)
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,15 @@
package vertex
var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"meta/llama3-405b-instruct-maas",
}
var ChannelName = "vertex-ai"

View File

@@ -0,0 +1,17 @@
package vertex
import "one-api/relay/channel/claude"
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@@ -0,0 +1,16 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@@ -0,0 +1,122 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"strings"
"fmt"
"time"
)
type Credentials struct {
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
}
var Cache = asynccache.NewAsyncCache(asynccache.Options{
RefreshDuration: time.Minute * 35,
EnableExpire: true,
ExpireDuration: time.Minute * 30,
Fetcher: func(key string) (interface{}, error) {
return nil, errors.New("not found")
},
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil
}
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
newToken, err := exchangeJwtForAccessToken(signedJWT)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
if err := Cache.SetDefault(cacheKey, newToken); err {
return newToken, nil
}
return newToken, nil
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
resp, err := http.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if accessToken, ok := result["access_token"].(string); ok {
return accessToken, nil
}
return "", fmt.Errorf("failed to get access token: %v", result)
}

View File

@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil return "", nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
return nil return nil
} }
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here // xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{} dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK dummyResp.StatusCode = http.StatusOK
return dummyResp, nil return dummyResp, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
splits := strings.Split(info.ApiKey, "|") splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 { if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)

View File

@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey) token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token) req.Set("Authorization", token)
return nil return nil
} }
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = zhipuStreamHandler(c, resp) err, usage = zhipuStreamHandler(c, resp)
} else { } else {

View File

@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey) token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token) req.Set("Authorization", token)
return nil return nil
} }
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return nil, nil return nil, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = openai.OaiStreamHandler(c, resp, info)
} else { } else {

View File

@@ -1,7 +1,7 @@
package zhipu_4v package zhipu_4v
var ModelList = []string{ var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus",
} }
var ChannelName = "zhipu_4v" var ChannelName = "zhipu_4v"

View File

@@ -2,7 +2,9 @@ package common
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/relay/constant" "one-api/relay/constant"
"strings" "strings"
"time" "time"
@@ -20,8 +22,11 @@ type RelayInfo struct {
setFirstResponse bool setFirstResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool
UsePrice bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
PromptTokens int PromptTokens int
@@ -30,6 +35,21 @@ type RelayInfo struct {
BaseUrl string BaseUrl string
SupportStreamOptions bool SupportStreamOptions bool
ShouldIncludeUsage bool ShouldIncludeUsage bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
InputAudioFormat string
OutputAudioFormat string
RealtimeTools []dto.RealTimeTool
IsFirstRequest bool
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
info := GenRelayInfo(c)
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
info.IsFirstRequest = true
return info
} }
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -57,17 +77,27 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenUnlimited: tokenUnlimited, TokenUnlimited: tokenUnlimited,
StartTime: startTime, StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second), FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
ApiType: apiType, ApiType: apiType,
ApiVersion: c.GetString("api_version"), ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" { if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType] info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c) info.ApiVersion = GetAPIVersion(c)
} }
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare { info.ChannelType == common.ChannelCloudflare {
@@ -140,3 +170,20 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
} }
return info return info
} }
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -23,6 +23,9 @@ const (
APITypeDify APITypeDify
APITypeJina APITypeJina
APITypeCloudflare APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -66,6 +69,12 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeJina apiType = APITypeJina
case common.ChannelCloudflare: case common.ChannelCloudflare:
apiType = APITypeCloudflare apiType = APITypeCloudflare
case common.ChannelTypeSiliconFlow:
apiType = APITypeSiliconFlow
case common.ChannelTypeVertexAi:
apiType = APITypeVertexAi
case common.ChannelTypeMistral:
apiType = APITypeMistral
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

View File

@@ -38,11 +38,13 @@ const (
RelayModeSunoSubmit RelayModeSunoSubmit
RelayModeRerank RelayModeRerank
RelayModeRealtime
) )
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") { if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
relayMode = RelayModeChatCompletions relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") { } else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions relayMode = RelayModeCompletions
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranslation relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") { } else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
} }
return relayMode return relayMode
} }

View File

@@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
return audioRequest, nil return audioRequest, nil
} }
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo) audioRequest, err := getAndValidAudioRequest(c, relayInfo)
@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
promptTokens := 0 promptTokens := 0
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
} }
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
@@ -87,11 +87,16 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
} }
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
@@ -122,27 +127,28 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { httpResp = resp.(*http.Response)
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp) openaiErr = service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
return nil return nil
} }

View File

@@ -38,9 +38,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Model == "" { if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2" imageRequest.Model = "dall-e-2"
} }
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
// Not "256x256", "512x512", or "1024x1024" // Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
@@ -50,6 +48,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
} }
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
//if imageRequest.N != 1 { //if imageRequest.N != 1 {
// return nil, errors.New("n must be 1") // return nil, errors.New("n must be 1")
//} //}
@@ -125,7 +126,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quota := int(imageRatio * groupRatio * common.QuotaPerUnit) quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 { if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
} }
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
@@ -148,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData) requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
var httpResp *http.Response
if resp != nil { if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") httpResp = resp.(*http.Response)
if resp.StatusCode != http.StatusOK { relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
openaiErr := service.RelayErrorHandler(resp) if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(httpResp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
} }
} }
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo) _, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)

View File

@@ -12,6 +12,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
@@ -146,6 +147,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest) err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil { if err != nil {
@@ -191,7 +193,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@@ -356,6 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
consumeQuota := true consumeQuota := true
var midjRequest dto.MidjourneyRequest var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -495,7 +498,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

Some files were not shown because too many files have changed in this diff Show More