Compare commits


152 Commits

Author SHA1 Message Date
wozulong
142d5fb209 merge upstream
Signed-off-by: wozulong <>
2024-08-07 14:41:33 +08:00
wozulong
f6ccd402e2 support gpt-4o-2024-08-06
Signed-off-by: wozulong <>
2024-08-07 14:40:57 +08:00
CalciumIon
4490258104 fix bug 2024-08-07 02:50:22 +08:00
CalciumIon
93c6d765c7 feat: support gpt-4o-2024-08-06 2024-08-07 02:49:02 +08:00
wozulong
1c371300ab merge upstream
Signed-off-by: wozulong <>
2024-08-06 15:54:28 +08:00
CalciumIon
67878731fc feat: log user id 2024-08-04 14:35:16 +08:00
CalciumIon
a0a3807bd4 chore: epay 2024-08-04 03:12:24 +08:00
CalciumIon
5d0d268c97 fix: epay 2024-08-04 00:18:32 +08:00
CalciumIon
0b4ef42d86 fix: channel type error 2024-08-03 22:41:47 +08:00
CalciumIon
0123ad4d61 fix: request id inconsistent after retry 2024-08-03 17:46:13 +08:00
CalciumIon
5acf074541 chore: improve auto-disable code 2024-08-03 17:32:28 +08:00
Calcium-Ion
8af0d9f22f Merge pull request #409 from utopeadia/main
Fix README errors
2024-08-03 17:31:08 +08:00
HowieWu
afd328efcf Fix README errors 2024-08-03 17:19:44 +08:00
CalciumIon
dd12a0052f chore: improve relay code 2024-08-03 17:12:16 +08:00
CalciumIon
fbe6cd75b1 chore: improve relay code 2024-08-03 17:07:14 +08:00
CalciumIon
8a9ff36fbf chore: improve relay code 2024-08-03 16:55:29 +08:00
CalciumIon
88ba8a840e feat: improve top-up order numbers 2024-08-03 01:28:18 +08:00
CalciumIon
e504665f68 feat: improve Gemini model version resolution logic 2024-08-02 17:23:59 +08:00
Calcium-Ion
54657ec27b Merge pull request #405 from utopeadia/main
Modify the Gemini version acquisition logic and add support for more gemini-1.5-pro/flash endpoints
2024-08-02 17:13:46 +08:00
Calcium-Ion
ae6b4e0be2 Merge pull request #399 from kakingone/main
add-mjp-discord-upload
2024-08-02 17:10:35 +08:00
HowieWu
fc0db4505c Update README.md
Add documentation for the Gemini version variable
2024-08-02 11:25:41 +08:00
HowieWu
22a98c5879 Change the Gemini version resolution logic
Use the GEMINI_MODEL_API environment variable to override the default version mapping, separating the model:version pairs with ","
-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta"
2024-08-02 11:20:26 +08:00
CalciumIon
f8f15bd1d0 fix: rpm fuzzy search 2024-08-01 18:14:10 +08:00
CalciumIon
b7690fe17d fix: log fuzzy search 2024-08-01 18:06:25 +08:00
CalciumIon
58b4c237a4 feat: improve rpm query 2024-08-01 17:39:18 +08:00
CalciumIon
54f6e660f1 feat: improve default start time for log queries 2024-08-01 17:36:26 +08:00
CalciumIon
3b1745c712 feat: improve log query conditions 2024-08-01 16:33:59 +08:00
CalciumIon
c92ab3b569 feat: add rpm and tpm data to logs (close #384) 2024-08-01 16:13:08 +08:00
wozulong
918690701d merge upstream
Signed-off-by: wozulong <>
2024-08-01 15:12:54 +08:00
CalciumIon
1501ccb919 fix: error channel name on notify. #338 2024-07-31 18:20:13 +08:00
Calcium-Ion
7f2a2a7de0 Merge pull request #400 from OswinWu/feat-gitignore-web-dist
feat: ignore npm build dir
2024-07-31 17:14:50 +08:00
Calcium-Ion
cce7d0258f Merge pull request #401 from HynoR/main
Support cloudflare llama3.1-8b
2024-07-31 17:14:30 +08:00
TAKO
c5e8d7ec20 Support cloudflare llama3.1-8b 2024-07-31 17:11:25 +08:00
OswinWu
fe16d51fe4 feat: ignore npm build dir 2024-07-31 16:50:19 +08:00
kakingone
2100d8ee0c addupload 2024-07-31 15:48:51 +08:00
CalciumIon
fbce36238e feat: support dify agent 2024-07-30 17:30:40 +08:00
CalciumIon
a6b6bcfe00 chore: remove useless code 2024-07-28 01:12:26 +08:00
CalciumIon
07e55cc999 chore: update token page 2024-07-28 00:05:53 +08:00
CalciumIon
b16e6bf423 fix: panic when get model ratio (close #392) 2024-07-27 18:09:09 +08:00
CalciumIon
b7bc205b73 feat: print user id when error 2024-07-27 17:55:36 +08:00
CalciumIon
88cc88c5d0 feat: support ollama tools 2024-07-27 17:51:05 +08:00
CalciumIon
ab1d61d910 feat: print user id when error 2024-07-27 17:47:30 +08:00
Calcium-Ion
d4a5df7373 Merge pull request #391 from OswinWu/fix-outlook-smtp
[fix] fix send email error using outlook smtp
2024-07-26 20:24:08 +08:00
CalciumIon
9e610c9429 fix: image quota (close #382) 2024-07-26 18:51:34 +08:00
Oswin
da490db6d3 [fix] fix send email error using outlook smtp 2024-07-26 17:47:36 +08:00
1808837298@qq.com
b8291dcd13 fix: gemini 2024-07-23 18:34:16 +08:00
Calcium-Ion
b0d9756c14 Merge pull request #380 from crabkun/main
fix: panic in the AWS Claude channel
2024-07-23 18:22:27 +08:00
Calcium-Ion
9dc07a8585 Merge pull request #383 from Yan-Zero/main
fix: the base64 format image_url for gemini
2024-07-23 18:22:06 +08:00
1808837298@qq.com
caaecb8d54 fix: first login error (close #385) 2024-07-23 18:25:43 +08:00
Yan Tau
b9454c3f14 fix: the base64 format image_url for gemini 2024-07-22 21:20:23 +08:00
wozulong
f0008e95fa merge upstream
Signed-off-by: wozulong <>
2024-07-22 10:51:49 +08:00
crabkun
96bdf97194 fix: panic in the AWS Claude channel 2024-07-21 01:27:29 +08:00
CalciumIon
3875b141c6 fix: gemini stream finish reason (close #378) 2024-07-19 17:16:20 +08:00
CalciumIon
12da7f64cd feat: update log search 2024-07-19 16:04:56 +08:00
CalciumIon
9ef3212e6c feat: update stream_options again 2024-07-19 15:06:07 +08:00
CalciumIon
20da8228df feat: update stream_options 2024-07-19 14:46:25 +08:00
CalciumIon
436d08b48f feat: update stream_options 2024-07-19 14:06:10 +08:00
CalciumIon
ce815a98d0 fix: nginx caching causing responses to be served to the wrong user 2024-07-19 13:39:05 +08:00
CalciumIon
e2cf6b1e14 feat: support gpt-4o-mini image tokens 2024-07-19 12:59:37 +08:00
wozulong
7a249b206d merge upstream
Signed-off-by: wozulong <>
2024-07-19 10:58:21 +08:00
CalciumIon
733b374596 Update README.md 2024-07-19 03:06:20 +08:00
CalciumIon
56afe47aa8 feat: update model ratio 2024-07-19 01:34:00 +08:00
CalciumIon
67b74ada00 feat: update model ratio 2024-07-19 01:29:08 +08:00
CalciumIon
e84300f4ae chore: gopool 2024-07-19 01:07:37 +08:00
CalciumIon
c9100b219f feat: support ali image 2024-07-19 00:45:52 +08:00
CalciumIon
f96291a25a feat: support gemini tool calling (close #368) 2024-07-18 20:28:47 +08:00
CalciumIon
14bf865034 feat: add UPDATE_TASK env 2024-07-18 17:26:21 +08:00
CalciumIon
70491ea1bb fix: image relay quota 2024-07-18 17:12:28 +08:00
CalciumIon
ae00a99cf5 feat: billing options for media requests 2024-07-18 17:04:19 +08:00
Calcium-Ion
a6a2d52fab Merge pull request #372 from Calcium-Ion/image
refactor: image relay
2024-07-18 00:41:48 +08:00
CalciumIon
fae918c055 chore: log format 2024-07-18 00:41:31 +08:00
CalciumIon
11fd993574 feat: support claude tool calling 2024-07-18 00:36:05 +08:00
CalciumIon
b0d5491a2a refactor: image relay 2024-07-17 23:50:37 +08:00
Calcium-Ion
0f94ff47b5 Merge pull request #367 from Calcium-Ion/audio
feat: support cloudflare tts
2024-07-17 17:34:59 +08:00
Calcium-Ion
9a8fd5cd6f Merge pull request #371 from daggeryu/patch-1
fix: embedding model dimensions
2024-07-17 17:03:42 +08:00
CalciumIon
7a0beb5793 fix: distribute panic 2024-07-17 17:01:25 +08:00
CalciumIon
e3b83f886f fix: try to fix panic #369 2024-07-17 16:43:55 +08:00
daggeryu
fd87260209 fix: embedding model dimensions 2024-07-17 16:40:44 +08:00
CalciumIon
4d0d18931d fix: try to fix panic #369 2024-07-17 16:38:56 +08:00
CalciumIon
86ca533f7a fix: fix bug 2024-07-16 23:40:52 +08:00
CalciumIon
ebb9b675b6 feat: support cloudflare audio 2024-07-16 23:24:47 +08:00
CalciumIon
bcc7f3edb2 refactor: audio relay 2024-07-16 22:07:10 +08:00
CalciumIon
11856ab39e Update README.md 2024-07-16 17:02:37 +08:00
CalciumIon
eb9b4b07ad feat: update register page 2024-07-16 15:48:56 +08:00
CalciumIon
963985e76c chore: update model ratio 2024-07-16 14:54:03 +08:00
CalciumIon
a3880d558a chore: mj 2024-07-15 22:14:30 +08:00
CalciumIon
ba27da9e2c fix: try to fix mj 2024-07-15 22:09:11 +08:00
CalciumIon
e262a9bd2c chore: openai stream 2024-07-15 22:07:50 +08:00
CalciumIon
9bbe8e7d1b fix: incorrect display of non-consumption entries in log details 2024-07-15 20:23:19 +08:00
CalciumIon
e2b9061650 fix: openai stream response 2024-07-15 19:06:13 +08:00
CalciumIon
220ab412e2 fix: openai response time 2024-07-15 18:14:07 +08:00
CalciumIon
7029065892 refactor: rework stream-mode logic 2024-07-15 18:04:05 +08:00
CalciumIon
0f687aab9a fix: azure stream options 2024-07-15 16:05:30 +08:00
Calcium-Ion
5e936b3923 Merge pull request #363 from dalefengs/main
fix: http code is not properly disabled channel
2024-07-14 15:35:27 +08:00
FENG
d55cb35c1c fix: http code is not properly disabled 2024-07-14 01:21:05 +08:00
Calcium-Ion
5be4cbcaaf Merge pull request #362 from dalefengs/main
fix: channel timeout auto-ban and auto-enable
2024-07-14 00:30:17 +08:00
FENG
e67aa370bc fix: channel timeout auto-ban and auto-enable 2024-07-14 00:14:07 +08:00
wozulong
0cc7f5cca6 merge upstream
Signed-off-by: wozulong <>
2024-07-11 14:10:10 +08:00
wozulong
ed86ec8b59 Merge remote-tracking branch 'upstream/main' 2024-07-05 11:22:44 +08:00
wozulong
895ee09b33 merge upstream
Signed-off-by: wozulong <>
2024-07-01 15:14:22 +08:00
wozulong
61006bed9e merge upstream
Signed-off-by: wozulong <>
2024-06-25 18:11:23 +08:00
wozulong
58591b8c2a fix linuxdo oauth
Signed-off-by: wozulong <>
2024-05-31 13:12:35 +08:00
wozulong
69b9950aa0 merge upstream
Signed-off-by: wozulong <>
2024-05-31 11:14:25 +08:00
wozulong
bcf1b160d1 update
Signed-off-by: wozulong <>
2024-05-28 12:09:19 +08:00
wozulong
1e0e40aa69 update
Signed-off-by: wozulong <>
2024-05-28 12:04:10 +08:00
wozulong
74392efbed merge upstream
Signed-off-by: wozulong <>
2024-05-28 11:57:05 +08:00
wozulong
cc020d6a40 fix req model
Signed-off-by: wozulong <>
2024-05-24 16:46:11 +08:00
wozulong
daa1741aed merge upstream
Signed-off-by: wozulong <>
2024-05-22 11:38:59 +08:00
wozulong
a25bcaa58f fix the type of the logprobs
Signed-off-by: wozulong <>
2024-05-22 11:30:38 +08:00
wozulong
d34b601dae Merge remote-tracking branch 'upstream/main' 2024-05-19 16:03:53 +08:00
wozulong
9932962320 merge upstream
Signed-off-by: wozulong <>
2024-05-16 19:00:58 +08:00
wozulong
00d6cda9ed refine completions req type
Signed-off-by: wozulong <>
2024-05-16 18:54:21 +08:00
wozulong
5ffb520363 merge upstream
Signed-off-by: wozulong <>
2024-05-15 12:36:29 +08:00
wozulong
e0f80cdb8f add gpt-4o
Signed-off-by: wozulong <>
2024-05-15 11:49:35 +08:00
wozulong
7a7a923504 add gpt-4o tokenizer
Signed-off-by: wozulong <>
2024-05-14 14:47:57 +08:00
wozulong
4f6c171a08 add gpt-4o model ratio
Signed-off-by: wozulong <>
2024-05-14 02:38:44 +08:00
wozulong
310f8c247e fix auth
Signed-off-by: wozulong <>
2024-05-11 13:47:51 +08:00
wozulong
811019bf5c lint fix
Signed-off-by: wozulong <>
2024-05-10 14:15:38 +08:00
wozulong
0ed6600437 Merge remote-tracking branch 'upstream/main' 2024-05-10 14:08:43 +08:00
wozulong
1fe7f14d57 merge upstream
Signed-off-by: wozulong <>
2024-04-28 14:04:39 +08:00
wozulong
a7bafec1bf merge upstream
Signed-off-by: wozulong <>
2024-04-28 14:04:19 +08:00
wozulong
ed951b3974 merge upstream
Signed-off-by: wozulong <>
2024-04-25 16:01:18 +08:00
wozulong
c74e43b8fd merge upstream
Signed-off-by: wozulong <>
2024-04-18 15:18:28 +08:00
wozulong
03bd9b0cc4 merge upstream
Signed-off-by: wozulong <>
2024-04-10 11:31:20 +08:00
wozulong
149902bd8a add gpt-4-turbo
Signed-off-by: wozulong <>
2024-04-10 11:21:12 +08:00
wozulong
cd3ed22045 Merge remote-tracking branch 'upstream/main' 2024-04-06 20:24:20 +08:00
wozulong
2329d387ca fix icon
Signed-off-by: wozulong <>
2024-04-06 20:24:07 +08:00
wozulong
7d18a8e2a9 merge upstream
Signed-off-by: wozulong <>
2024-04-05 22:10:07 +08:00
wozulong
80af3718d0 lint fix
Signed-off-by: wozulong <>
2024-04-01 10:16:35 +08:00
wozulong
77ea6bec46 Merge remote-tracking branch 'upstream/main' 2024-04-01 10:15:48 +08:00
wozulong
c0ab8ae953 update github action
Signed-off-by: wozulong <>
2024-03-27 18:04:27 +08:00
wozulong
923c2dee32 merge upstream
Signed-off-by: wozulong <>
2024-03-27 18:01:43 +08:00
wozulong
ea17a46d8e fixed pagination issue in the log list
Signed-off-by: wozulong <>
2024-03-25 16:36:57 +08:00
wozulong
bfe9e5d25a Merge remote-tracking branch 'upstream/main' 2024-03-25 15:49:12 +08:00
wozulong
831ff47254 Merge remote-tracking branch 'upstream/main' 2024-03-24 00:18:02 +08:00
wozulong
11eaba6b5d fix gpt-3.5-turbo points
Signed-off-by: wozulong <>
2024-03-24 00:17:49 +08:00
wozulong
c2e4ec25c8 merge upstream
Signed-off-by: wozulong <>
2024-03-23 22:51:18 +08:00
wozulong
8537f10412 save stripe customer id
Signed-off-by: wozulong <>
2024-03-23 21:59:15 +08:00
wozulong
71d60eeef7 merge upstream
Signed-off-by: wozulong <>
2024-03-22 18:09:13 +08:00
wozulong
247ae0988f replace epay with stripe
Signed-off-by: wozulong <>
2024-03-22 18:00:20 +08:00
wozulong
0907fa6994 fix next uid
Signed-off-by: wozulong <>
2024-03-21 15:02:47 +08:00
wozulong
9855343aa8 Merge remote-tracking branch 'upstream/main' 2024-03-21 15:01:51 +08:00
wozulong
bd50fde268 Merge remote-tracking branch 'upstream/main' 2024-03-21 13:10:21 +08:00
wozulong
4267de5642 fix yarn timeout
Signed-off-by: wozulong <>
2024-03-20 19:14:17 +08:00
wozulong
8b55116563 update build files
Signed-off-by: wozulong <>
2024-03-20 18:31:22 +08:00
wozulong
f35e63e3f3 limit 'LINUX DO' trust level now available
Signed-off-by: wozulong <>
2024-03-20 16:54:38 +08:00
wozulong
17c409de23 update gh action
Signed-off-by: wozulong <>
2024-03-20 14:11:21 +08:00
wozulong
e4753e7411 synced with upstream
Signed-off-by: wozulong <>
2024-03-20 13:52:10 +08:00
paderlol
9adefa80b9 Optimizing Docker image builds (#1)
Co-authored-by: pader.zhang <pader.zhang@starlight-sms.com>
2024-03-14 23:10:05 -07:00
wozulong
4ce2381182 update issue template
Signed-off-by: wozulong <>
2024-03-14 20:02:22 +08:00
我秦始皇
62afc21ea5 Merge branch 'Calcium-Ion:main' into main 2024-03-14 03:56:31 -07:00
wozulong
7ddb7c586d 1. add LINUX DO oauth
2. fix oauth reg aff issue

Signed-off-by: wozulong <>
2024-03-14 18:53:54 +08:00
139 changed files with 8193 additions and 3068 deletions

View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: 项目群聊 - name: 交流社区
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg url: https://linux.do
about: QQ 群629454374 about: 项目交流社区

View File

@@ -1,9 +1,6 @@
name: Publish Docker image (amd64) name: Publish Docker image (amd64)
on: on:
push:
tags:
- '*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
@@ -42,7 +39,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
calciumion/new-api pengzhile/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
calciumion/new-api pengzhile/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

.gitignore (vendored, 3 changed lines)
View File

@@ -5,4 +5,5 @@ upload
*.db *.db
build build
*.db-journal *.db-journal
logs logs
web/dist

View File

@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend: build-frontend:
@echo "Building frontend..." @echo "Building frontend..."
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build @cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
start-backend: start-backend:
@echo "Starting backend dev server..." @echo "Starting backend dev server..."

View File

@@ -2,15 +2,21 @@
# New API # New API
> [!NOTE] > [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。 > [!IMPORTANT]
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!NOTE] > [!TIP]
> 最新版Docker镜像 calciumion/new-api:latest > 最新版Docker镜像`calciumion/new-api:latest`
> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR > 默认账号root 密码123456
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
## 主要变更 ## 主要变更
此分叉版本的主要变更如下: 此分叉版本的主要变更如下:
@@ -18,9 +24,9 @@
1. 全新的UI界面部分界面还待更新 1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md) 2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md)
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付 + [x] 易支付
4. 支持用key查询使用额度: 4. 支持用key查询使用额度:
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
5. 渠道显示已使用额度,支持指定组织访问 5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量 6. 分页支持选择每页显示数量
7. 兼容原版One API的数据库可直接使用原版数据库one-api.db 7. 兼容原版One API的数据库可直接使用原版数据库one-api.db
@@ -38,7 +44,7 @@
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-* 1. 第三方模型 **gps** gpt-4-gizmo-*, g-*
2. 智谱glm-4vglm-4v识图 2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 3. Anthropic Claude 3
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
@@ -49,31 +55,17 @@
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md) 9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify 10. Dify
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## 比原版One API多出的配置 ## 比原版One API多出的配置
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` 可选值为 `true``false` - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`覆盖客户端stream_options参数请求上游返回流模式usage目前仅支持 `OpenAI` 渠道类型 - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
## 部署 ## 部署
### 部署要求 ### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机) - 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
@@ -96,8 +88,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问 # 注意数据库要开启远程访问并且只允许服务器IP访问
``` ```
### 默认账号密码
默认账号root 密码123456 ## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## Midjourney接口设置文档 ## Midjourney接口设置文档
[对接文档](Midjourney.md) [对接文档](Midjourney.md)
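
The README changes above list the fork-specific environment variables (`STREAMING_TIMEOUT`, `DIFY_DEBUG`, `FORCE_STREAM_OPTION`, `GET_MEDIA_TOKEN`, `GET_MEDIA_TOKEN_NOT_STREAM`, `UPDATE_TASK`, `GEMINI_MODEL_MAP`). A minimal deployment sketch, assuming the published `calciumion/new-api:latest` image from the README; the host path and variable values here are illustrative, not taken from the diff:

```
docker run --name new-api -d --restart always -p 3000:3000 \
  -e TZ=Asia/Shanghai \
  -e STREAMING_TIMEOUT=60 \
  -e FORCE_STREAM_OPTION=true \
  -e GET_MEDIA_TOKEN=true \
  -e UPDATE_TASK=true \
  -e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-flash:v1beta" \
  -v /your/data/path:/data \
  calciumion/new-api:latest
```

Variables left unset fall back to the defaults documented above: `FORCE_STREAM_OPTION`, `GET_MEDIA_TOKEN`, `GET_MEDIA_TOKEN_NOT_STREAM`, and `UPDATE_TASK` default to `true`, and `STREAMING_TIMEOUT` defaults to 30 seconds.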

View File

@@ -9,9 +9,20 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Pay Settings
var StripeApiSecret = ""
var StripeWebhookSecret = ""
var StripePriceId = ""
var PaymentEnabled = false
var StripeUnitPrice = 8.0
var MinTopUp = 5
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API" var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var OutProxyUrl = ""
var Footer = "" var Footer = ""
var Logo = "" var Logo = ""
var TopUpLink = "" var TopUpLink = ""
@@ -41,10 +52,12 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制 var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制 var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
@@ -75,6 +88,10 @@ var SMTPToken = ""
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@@ -176,6 +193,12 @@ const (
ChannelStatusAutoDisabled = 3 ChannelStatusAutoDisabled = 3
) )
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1 ChannelTypeOpenAI = 1

View File

@@ -0,0 +1,32 @@
package common
import (
"errors"
"net/smtp"
)
type outlookAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &outlookAuth{username, password}
}
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown fromServer")
}
}
return nil, nil
}

View File

@@ -62,6 +62,9 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil { if err != nil {
return err return err
} }
} else if strings.HasSuffix(SMTPAccount, "outlook.com") {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else { } else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} }

common/hash.go (new file, 84 lines)
View File

@@ -0,0 +1,84 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@@ -1,4 +1,4 @@
package service package common
import ( import (
"bytes" "bytes"
@@ -8,7 +8,6 @@ import (
"golang.org/x/image/webp" "golang.org/x/image/webp"
"image" "image"
"io" "io"
"one-api/common"
"strings" "strings"
) )
@@ -31,9 +30,24 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err return config, format, base64String, err
} }
func IsImageUrl(url string) (bool, error) {
resp, err := ProxiedHttpHead(url, OutProxyUrl)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据 // GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoImageRequest(url) isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := ProxiedHttpGet(url, OutProxyUrl)
if err != nil { if err != nil {
return return
} }
@@ -52,9 +66,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
} }
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoImageRequest(imageUrl) response, err := ProxiedHttpGet(imageUrl, OutProxyUrl)
if err != nil { if err != nil {
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err return image.Config{}, "", err
} }
defer response.Body.Close() defer response.Body.Close()
@@ -66,7 +80,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
var readData []byte var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制 // 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData))) additionalData := make([]byte, limit-int64(len(readData)))
@@ -92,11 +106,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader) config, format, err := image.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
common.SysLog(err.Error()) SysLog(err.Error())
config, err = webp.DecodeConfig(reader) config, err = webp.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
common.SysLog(err.Error()) SysLog(err.Error())
} }
format = "webp" format = "webp"
} }

View File

@@ -2,7 +2,6 @@ package common
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
} }
} }
// LogJson 仅供测试使用 only for test func LogQuotaF(quota float64) string {
func LogJson(ctx context.Context, msg string, obj any) { if DisplayInCurrencyEnabled {
jsonStr, err := json.Marshal(obj) return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit)
if err != nil { } else {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return fmt.Sprintf("%d 点额度", int64(quota))
return
} }
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
} }

View File

@@ -3,6 +3,7 @@ package common
import ( import (
"encoding/json" "encoding/json"
"strings" "strings"
"sync"
) )
// from songquanpeng/one-api // from songquanpeng/one-api
@@ -22,39 +23,41 @@ const (
var defaultModelRatio = map[string]float64{ var defaultModelRatio = map[string]float64{
//"midjourney": 50, //"midjourney": 50,
"gpt-4-gizmo-*": 15, "gpt-4-gizmo-*": 15,
"gpt-4-all": 15, "g-*": 15,
"gpt-4o-all": 15, "gpt-4": 15,
"gpt-4": 15, "gpt-4-0314": 15,
//"gpt-4-0314": 15, //deprecated "gpt-4-0613": 15,
"gpt-4-0613": 15, "gpt-4-32k": 30,
"gpt-4-32k": 30, "gpt-4-32k-0314": 30,
//"gpt-4-32k-0314": 30, //deprecated "gpt-4-32k-0613": 30,
"gpt-4-32k-0613": 30, "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens "gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4o": 2.5, // $0.005 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-0125": 0.25, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens "gpt-3.5-turbo-0125": 0.25,
"text-ada-001": 0.2, "babbage-002": 0.2, // $0.0004 / 1K tokens
"text-babbage-001": 0.25, "davinci-002": 1, // $0.002 / 1K tokens
"text-curie-001": 1, "text-ada-001": 0.2,
//"text-davinci-002": 10, "text-babbage-001": 0.25,
//"text-davinci-003": 10, "text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
@@ -76,9 +79,9 @@ var defaultModelRatio = map[string]float64{
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -105,12 +108,13 @@ var defaultModelRatio = map[string]float64{
"gemini-1.0-pro-latest": 1, "gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1, "gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1, "gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens "glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572, "glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens
@@ -158,8 +162,12 @@ var defaultModelRatio = map[string]float64{
} }
var defaultModelPrice = map[string]float64{ var defaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-2": 0.02,
"dall-e-3": 0.04, "dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1, "gpt-4-gizmo-*": 0.1,
"g-*": 0.1,
"mj_imagine": 0.1, "mj_imagine": 0.1,
"mj_variation": 0.1, "mj_variation": 0.1,
"mj_reroll": 0.1, "mj_reroll": 0.1,
@@ -175,22 +183,38 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05, "mj_describe": 0.05,
"mj_upscale": 0.05, "mj_upscale": 0.05,
"swap_face": 0.05, "swap_face": 0.05,
"mj_upload": 0.05,
} }
var modelPrice map[string]float64 = nil var (
var modelRatio map[string]float64 = nil modelPriceMap = make(map[string]float64)
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"g-*": 2,
"gpt-4-all": 2, "gpt-4-all": 2,
"gpt-4o-all": 2,
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
return modelPriceMap
} }
func ModelPrice2JSONString() string { func ModelPrice2JSONString() string {
if modelPrice == nil { GetModelPriceMap()
modelPrice = defaultModelPrice jsonBytes, err := json.Marshal(modelPriceMap)
}
jsonBytes, err := json.Marshal(modelPrice)
if err != nil { if err != nil {
SysError("error marshalling model price: " + err.Error()) SysError("error marshalling model price: " + err.Error())
} }
@@ -198,19 +222,21 @@ func ModelPrice2JSONString() string {
} }
func UpdateModelPriceByJSONString(jsonStr string) error { func UpdateModelPriceByJSONString(jsonStr string) error {
modelPrice = make(map[string]float64) modelPriceMapMutex.Lock()
return json.Unmarshal([]byte(jsonStr), &modelPrice) defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
} }
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false // GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) { func GetModelPrice(name string, printErr bool) (float64, bool) {
if modelPrice == nil { GetModelPriceMap()
modelPrice = defaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
price, ok := modelPrice[name] price, ok := modelPriceMap[name]
if !ok { if !ok {
if printErr { if printErr {
SysError("model price not found: " + name) SysError("model price not found: " + name)
@@ -220,18 +246,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true return price, true
} }
func GetModelPriceMap() map[string]float64 { func GetModelRatioMap() map[string]float64 {
if modelPrice == nil { modelRatioMapMutex.Lock()
modelPrice = defaultModelPrice defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
} }
return modelPrice return modelRatioMap
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
if modelRatio == nil { GetModelRatioMap()
modelRatio = defaultModelRatio jsonBytes, err := json.Marshal(modelRatioMap)
}
jsonBytes, err := json.Marshal(modelRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) SysError("error marshalling model ratio: " + err.Error())
} }
@@ -239,18 +265,20 @@ func ModelRatio2JSONString() string {
} }
func UpdateModelRatioByJSONString(jsonStr string) error { func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatio = make(map[string]float64) modelRatioMapMutex.Lock()
return json.Unmarshal([]byte(jsonStr), &modelRatio) defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
if modelRatio == nil { GetModelRatioMap()
modelRatio = defaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
ratio, ok := modelRatio[name] ratio, ok := modelRatioMap[name]
if !ok { if !ok {
SysError("model ratio not found: " + name) SysError("model ratio not found: " + name)
return 30 return 30
@@ -289,29 +317,37 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { if strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3 return 3
} }
if strings.HasSuffix(name, "1106") { if strings.HasSuffix(name, "1106") {
return 2 return 2
} }
if name == "gpt-3.5-turbo" {
return 3
}
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") { if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4o") { if strings.HasPrefix(name, "gpt-4o-mini") || "gpt-4o-2024-08-06" == name {
return 4
}
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
return 3 return 3
} }
return 2 return 2
} }
if strings.Contains(name, "claude-instant-1") { if strings.HasPrefix(name, "claude-instant-1") {
return 3 return 3
} else if strings.Contains(name, "claude-2") { } else if strings.HasPrefix(name, "claude-2") {
return 3 return 3
} else if strings.Contains(name, "claude-3") { } else if strings.HasPrefix(name, "claude-3") {
return 5 return 5
} }
if strings.HasPrefix(name, "mistral-") { if strings.HasPrefix(name, "mistral-") {

common/str.go (new file, 73 lines)
View File

@@ -0,0 +1,73 @@
package common
import (
"encoding/json"
"math/rand"
"strconv"
"unsafe"
)
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
}
return str
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func MapToJsonStr(m map[string]interface{}) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return nil
}
return m
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}

View File

@@ -1,19 +1,22 @@
package common package common
import ( import (
"encoding/json" "context"
"errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template" "html/template"
"log" "log"
"math/rand" "math/rand"
"net" "net"
"net/http"
"net/url"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unsafe"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@@ -159,15 +162,6 @@ func GenerateKey() string {
return string(key) return string(key)
} }
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomInt(max int) int { func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano()) //rand.Seed(time.Now().UnixNano())
return rand.Intn(max) return rand.Intn(max)
@@ -194,56 +188,60 @@ func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id) return fmt.Sprintf("%s (request id: %s)", message, id)
} }
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func RandomSleep() { func RandomSleep() {
// Sleep for 0-3000 ms // Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
} }
func MapToJsonStr(m map[string]interface{}) string { func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
bytes, err := json.Marshal(m) if "" == proxyUrl {
if err != nil { return &http.Client{}, nil
return ""
} }
return string(bytes)
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
if strings.HasPrefix(proxyUrl, "http") {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(u),
},
}, nil
} else if strings.HasPrefix(proxyUrl, "socks") {
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
return nil, err
}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
},
}, nil
}
return nil, errors.New("unsupported proxy type")
} }
func MapToJsonStrFloat(m map[string]float64) string { func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
bytes, err := json.Marshal(m) client, err := GetProxiedHttpClient(proxyUrl)
if err != nil { if err != nil {
return "" return nil, err
} }
return string(bytes)
return client.Get(url)
} }
func StrToMap(str string) map[string]interface{} { func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
m := make(map[string]interface{}) client, err := GetProxiedHttpClient(proxyUrl)
err := json.Unmarshal([]byte(str), &m)
if err != nil { if err != nil {
return nil return nil, err
} }
return m
return client.Head(url)
} }

View File

@@ -1,7 +1,10 @@
package constant package constant
import ( import (
"fmt"
"one-api/common" "one-api/common"
"os"
"strings"
) )
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@@ -9,3 +12,35 @@ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption 覆盖请求参数强制返回usage信息 // ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}

View File

@@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION" MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN" MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE" MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
) )
var MidjourneyModel2Action = map[string]string{ var MidjourneyModel2Action = map[string]string{
@@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation, "mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan, "mj_pan": MjActionPan,
"swap_face": MjActionSwapFace, "swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
} }

View File

@@ -1,9 +0,0 @@
package constant
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""
var WorkerValidKey = ""
func EnableWorker() bool {
return WorkerUrl != ""
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"io" "io"
"math" "math"
"net/http" "net/http"
@@ -25,7 +26,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now() tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney { if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil return errors.New("midjourney channel test is not supported"), nil
@@ -58,8 +59,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
modelMap := make(map[string]string) modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap) err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil { if err != nil {
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return err, &openaiErr
} }
if modelMap[testModel] != "" { if modelMap[testModel] != "" {
testModel = modelMap[testModel] testModel = modelMap[testModel]
@@ -86,9 +86,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta, *request) adaptor.Init(meta)
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -103,12 +103,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
return err, nil return err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { if resp != nil && resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp) err := service.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error return fmt.Errorf("%s", respErr.Error.Message), respErr
} }
if usage == nil { if usage == nil {
return errors.New("usage is nil"), nil return errors.New("usage is nil"), nil
@@ -218,11 +218,11 @@ func testAllChannels(notify bool) error {
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
go func() { gopool.Go(func() {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel, "") err, openaiWithStatusErr := testChannel(channel, "")
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
@@ -231,27 +231,29 @@ func testAllChannels(notify bool) error {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true ban = true
} }
if openaiErr != nil {
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) // request error disables the channel
ban = true if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
} }
// parse *int to bool // parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 { if !channel.GetAutoBan() {
ban = false ban = false
} }
if openaiErr != nil {
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ // disable channel
StatusCode: -1, if ban && isChannelEnabled {
Error: *openaiErr, service.DisableChannel(channel.Id, channel.Name, err.Error())
LocalError: false,
}
if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
} }
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
} }
@@ -264,7 +266,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
} }
}() })
return nil return nil
} }

View File

@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
} }
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

controller/linuxdo.go (new file, 239 lines)
View File

@@ -0,0 +1,239 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
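
The new LINUX DO OAuth controller above exchanges the authorization code for a token at connect.linux.do, fetches /api/user, and rejects users below common.LinuxDoMinLevel. A self-contained sketch (not part of the diff) of that trust-level gate; the sample payload and the minimum level are made up, the real threshold comes from configuration:

// Stand-alone illustration of the /api/user payload shape and the trust-level check.
package main

import (
	"encoding/json"
	"fmt"
)

type LinuxDoUser struct {
	ID         int    `json:"id"`
	Username   string `json:"username"`
	Name       string `json:"name"`
	Active     bool   `json:"active"`
	TrustLevel int    `json:"trust_level"`
	Silenced   bool   `json:"silenced"`
}

func main() {
	payload := []byte(`{"id":12345,"username":"alice","name":"Alice","active":true,"trust_level":2,"silenced":false}`)
	var u LinuxDoUser
	if err := json.Unmarshal(payload, &u); err != nil {
		panic(err)
	}
	minLevel := 1 // stands in for common.LinuxDoMinLevel (assumed value)
	switch {
	case u.ID == 0:
		fmt.Println("invalid response: empty user")
	case u.TrustLevel < minLevel:
		fmt.Println("trust level too low, login rejected")
	default:
		fmt.Printf("accepted user %s (trust level %d)\n", u.Username, u.TrustLevel)
	}
}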


@@ -24,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
-logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
+logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -35,6 +35,7 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
+"total": total,
"data": logs,
})
return
@@ -58,7 +59,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
-logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
+logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -69,6 +70,7 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
+"total": total,
"data": logs,
})
return


@@ -146,28 +146,26 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
+shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
-err = model.CacheUpdateUserQuota(task.UserId)
-if err != nil {
-common.LogError(ctx, "error update user quota cache: "+err.Error())
-} else {
-quota := task.Quota
-if quota != 0 {
-err = model.IncreaseUserQuota(task.UserId, quota)
-if err != nil {
-common.LogError(ctx, "fail to increase user quota: "+err.Error())
-}
-logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
-model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
-}
-}
+if task.Quota != 0 {
+shouldReturnQuota = true
+}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+} else {
+if shouldReturnQuota {
+err = model.IncreaseUserQuota(task.UserId, task.Quota)
+if err != nil {
+common.LogError(ctx, "fail to increase user quota: "+err.Error())
+}
+logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota))
+model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+}
}
}
}
@@ -235,7 +233,7 @@ func GetAllMidjourney(c *gin.Context) {
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
-midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
+midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
@@ -267,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
-midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
+midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}


@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled, "telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName, "telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName, "system_name": common.SystemName,
@@ -45,9 +47,9 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": constant.ServerAddress, "server_address": common.ServerAddress,
"price": constant.Price, "stripe_unit_price": common.StripeUnitPrice,
"min_topup": constant.MinTopUp, "min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": common.TopUpLink,
@@ -61,7 +63,7 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "", "payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
}, },
}) })
@@ -204,7 +206,7 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
-link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
+link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+


@@ -131,7 +131,7 @@ func init() {
}
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
-adaptor.Init(meta, dto.GeneralOpenAIRequest{})
+adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
}


@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{


@@ -2,6 +2,7 @@ package controller
import (
"bytes"
+"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -22,13 +23,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
-err = relay.RelayImageHelper(c, relayMode)
+err = relay.ImageHelper(c, relayMode)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
-err = relay.AudioHelper(c, relayMode)
+err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
default:
@@ -39,43 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
-retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
-channelId := c.GetInt("channel_id")
-channelType := c.GetInt("channel_type")
group := c.GetString("group")
originalModel := c.GetString("original_model")
-openaiErr := relayHandler(c, relayMode)
-c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
-if openaiErr != nil {
-go processChannelError(c, channelId, channelType, openaiErr)
-} else {
-retryTimes = 0
-}
-for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
-channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+var openaiErr *dto.OpenAIErrorWithStatusCode
+for i := 0; i <= common.RetryTimes; i++ {
+channel, err := getChannel(c, group, originalModel, i)
if err != nil {
-common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+common.LogError(c, err.Error())
+openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
-channelId = channel.Id
-useChannel := c.GetStringSlice("use_channel")
-useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
-c.Set("use_channel", useChannel)
-common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
-middleware.SetupContextForSelectedChannel(c, channel, originalModel)
-requestBody, err := common.GetRequestBody(c)
-c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
-openaiErr = relayHandler(c, relayMode)
-if openaiErr != nil {
-go processChannelError(c, channelId, channel.Type, openaiErr)
-}
+openaiErr = relayRequest(c, relayMode, channel)
+if openaiErr == nil {
+return // 成功处理请求,直接返回
+}
+go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+break
+}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-common.LogInfo(c.Request.Context(), retryLogStr)
+common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
@@ -89,7 +82,42 @@ func Relay(c *gin.Context) {
}
}
-func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
@@ -113,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
@@ -128,11 +160,11 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
-func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
-autoBan := c.GetBool("auto_ban")
-common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
+func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
+// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
+// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
-channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
@@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
-common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
-common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
@@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
-common.LogInfo(c.Request.Context(), retryLogStr)
+common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
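
The reworked Relay loop above folds the first attempt and the retries into one for loop: attempt 0 reuses the channel the distribution middleware already picked (getChannel with retryCount == 0), later attempts re-select a channel, and the loop ends on success or when shouldRetry declines. A stand-alone sketch of that control flow with stand-in types (the real code uses *model.Channel and dto.OpenAIErrorWithStatusCode):

// Stand-alone sketch of the retry pattern; channel, relayError and the helpers
// here are simplified stand-ins, not the real one-api types.
package main

import "fmt"

type channel struct{ id int }
type relayError struct{ status int }

const retryTimes = 3 // stands in for common.RetryTimes

// attempt 0 returns the channel the middleware already picked; later attempts
// stand in for re-selecting via model.CacheGetRandomSatisfiedChannel.
func getChannel(attempt int) *channel {
	if attempt == 0 {
		return &channel{id: 1}
	}
	return &channel{id: 100 + attempt}
}

func relayRequest(ch *channel) *relayError {
	if ch.id == 1 {
		return &relayError{status: 429} // pretend the first channel is rate limited
	}
	return nil
}

func shouldRetry(err *relayError, remaining int) bool {
	return err != nil && remaining > 0 && err.status != 400
}

func main() {
	var lastErr *relayError
	for i := 0; i <= retryTimes; i++ {
		ch := getChannel(i)
		lastErr = relayRequest(ch)
		if lastErr == nil {
			fmt.Printf("succeeded on channel #%d (attempt %d)\n", ch.id, i)
			return
		}
		if !shouldRetry(lastErr, retryTimes-i) {
			break
		}
	}
	fmt.Println("all attempts failed, last status:", lastErr.status)
}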

controller/stripe.go (new file)

@@ -0,0 +1,97 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}
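
StripeWebhook relies on webhook.ConstructEvent to reject payloads whose Stripe-Signature header does not match the endpoint secret. A stdlib-only sketch (not part of the diff, values made up) of that signature scheme — a timestamp plus an HMAC-SHA256 over "<timestamp>.<payload>" — which can be handy when testing the handler locally:

// Minimal signer/verifier for the Stripe-Signature header format "t=...,v1=...".
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"strings"
	"time"
)

func sign(payload []byte, secret string, ts int64) string {
	mac := hmac.New(sha256.New, []byte(secret))
	fmt.Fprintf(mac, "%d.%s", ts, payload)
	return fmt.Sprintf("t=%d,v1=%s", ts, hex.EncodeToString(mac.Sum(nil)))
}

func verify(payload []byte, header, secret string) bool {
	var ts int64
	var v1 string
	for _, part := range strings.Split(header, ",") {
		if strings.HasPrefix(part, "t=") {
			fmt.Sscanf(part, "t=%d", &ts)
		} else if strings.HasPrefix(part, "v1=") {
			v1 = strings.TrimPrefix(part, "v1=")
		}
	}
	mac := hmac.New(sha256.New, []byte(secret))
	fmt.Fprintf(mac, "%d.%s", ts, payload)
	expected := hex.EncodeToString(mac.Sum(nil))
	return hmac.Equal([]byte(expected), []byte(v1))
}

func main() {
	secret := "whsec_test_secret" // stands in for common.StripeWebhookSecret
	payload := []byte(`{"type":"checkout.session.completed"}`)
	header := sign(payload, secret, time.Now().Unix())
	fmt.Println("signature valid:", verify(payload, header, secret))
}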


@@ -1,22 +1,20 @@
package controller
-import "C"
import (
"fmt"
-"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
-"github.com/samber/lo"
+"github.com/stripe/stripe-go/v76"
+"github.com/stripe/stripe-go/v76/checkout/session"
"log"
-"net/url"
"one-api/common"
-"one-api/constant"
"one-api/model"
-"one-api/service"
"strconv"
-"sync"
+"strings"
"time"
)
-type EpayRequest struct {
+type PayRequest struct {
Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
@@ -27,196 +25,114 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"`
}
-func GetEpayClient() *epay.Client {
-if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
-return nil
+func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
+if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
+return "", fmt.Errorf("无效的Stripe API密钥")
}
-withUrl, err := epay.NewClient(&epay.Config{
-PartnerID: constant.EpayId,
-Key: constant.EpayKey,
-}, constant.PayAddress)
+stripe.Key = common.StripeApiSecret
+params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil {
-return nil
+return "", err
}
-return withUrl
+return result.URL, nil
}
-func getPayMoney(amount float64, user model.User) float64 {
-if !common.DisplayInCurrencyEnabled {
-amount = amount / common.QuotaPerUnit
-}
-// 别问为什么用float64问就是这么点钱没必要
-topupGroupRatio := common.GetTopupGroupRatio(user.Group)
-if topupGroupRatio == 0 {
-topupGroupRatio = 1
-}
-payMoney := amount * constant.Price * topupGroupRatio
-return payMoney
+func GetPayAmount(count float64) float64 {
+return count * common.StripeUnitPrice
}
-func getMinTopup() int {
-minTopup := constant.MinTopUp
-if !common.DisplayInCurrencyEnabled {
-minTopup = minTopup * int(common.QuotaPerUnit)
-}
-return minTopup
+func GetChargedAmount(count float64, user model.User) float64 {
+topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
+if topUpGroupRatio == 0 {
+topUpGroupRatio = 1
+}
+return count * topUpGroupRatio
}
-func RequestEpay(c *gin.Context) {
-var req EpayRequest
+func RequestPayLink(c *gin.Context) {
+var req PayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
-c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
-if req.Amount < getMinTopup() {
-c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
+if !common.PaymentEnabled {
+c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
-payMoney := getPayMoney(float64(req.Amount), *user)
-if payMoney < 0.01 {
-c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
-return
-}
-var payType epay.PurchaseType
-if req.PaymentMethod == "zfb" {
-payType = epay.Alipay
-}
-if req.PaymentMethod == "wx" {
-req.PaymentMethod = "wxpay"
-payType = epay.WechatPay
-}
-callBackAddress := service.GetCallbackAddress()
-returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
-notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
-tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
-client := GetEpayClient()
-if client == nil {
-c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
-return
-}
-uri, params, err := client.Purchase(&epay.PurchaseArgs{
-Type: payType,
-ServiceTradeNo: "A" + tradeNo,
-Name: "B" + tradeNo,
-Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
-Device: epay.PC,
-NotifyUrl: notifyUrl,
-ReturnUrl: returnUrl,
-})
+chargedMoney := GetChargedAmount(float64(req.Amount), *user)
+reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
+referenceId := "ref_" + common.Sha1(reference)
+payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
if err != nil {
+log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
-amount := req.Amount
-if !common.DisplayInCurrencyEnabled {
-amount = amount / int(common.QuotaPerUnit)
-}
topUp := &model.TopUp{
UserId: id,
-Amount: amount,
-Money: payMoney,
-TradeNo: "A" + tradeNo,
+Amount: req.Amount,
+Money: chargedMoney,
+TradeNo: referenceId,
CreateTime: time.Now().Unix(),
-Status: "pending",
+Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
-c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
-}
-// tradeNo lock
-var orderLocks sync.Map
-var createLock sync.Mutex
+c.JSON(200, gin.H{
+"message": "success",
+"data": gin.H{
+"payLink": payLink,
+},
+})
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
}
func RequestAmount(c *gin.Context) {
@@ -226,17 +142,23 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if !common.PaymentEnabled {
if req.Amount < getMinTopup() { c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())}) return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
payMoney := getPayMoney(float64(req.Amount), *user) payMoney := GetPayAmount(float64(req.Amount))
if payMoney <= 0.01 { chargedMoney := GetChargedAmount(float64(req.Amount), *user)
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(200, gin.H{
return "message": "success",
} "data": gin.H{
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)}) "payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
},
})
} }
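
RequestAmount now reports two figures: payAmount (units × common.StripeUnitPrice, what the user pays Stripe) and chargedAmount (units × the top-up group ratio, what is credited). A tiny worked example with assumed numbers:

// Stand-alone arithmetic sketch; the unit price and group ratio are made-up values.
package main

import "fmt"

func main() {
	count := 10.0          // requested top-up units
	stripeUnitPrice := 8.0 // stands in for common.StripeUnitPrice (assumed)
	topUpGroupRatio := 1.0 // the user's top-up group ratio; 0 falls back to 1

	payAmount := count * stripeUnitPrice
	chargedAmount := count * topUpGroupRatio
	fmt.Printf("payAmount=%.2f chargedAmount=%.2f\n", payAmount, chargedAmount)
}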


@@ -66,6 +66,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username) session.Set("username", user.Username)
session.Set("role", user.Role) session.Set("role", user.Role)
session.Set("status", user.Status) session.Set("status", user.Status)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -517,7 +518,7 @@ func UpdateSelf(c *gin.Context) {
return
}
-func DeleteUser(c *gin.Context) {
+func HardDeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -526,7 +527,7 @@ func DeleteUser(c *gin.Context) {
})
return
}
-originUser, err := model.GetUserById(id, false)
+originUser, err := model.GetUserByIdUnscoped(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -550,9 +551,23 @@ func DeleteUser(c *gin.Context) {
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
@@ -791,11 +806,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
-var lock = sync.Mutex{}
+var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) {
-lock.Lock()
-defer lock.Unlock()
+topUpLock.Lock()
+defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {


@@ -2,18 +2,17 @@ version: '3.4'
services:
new-api:
-image: calciumion/new-api:latest
-# build: .
+image: pengzhile/new-api:latest
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
- "3000:3000"
volumes:
-- ./data:/data
+- ./data/new-api:/data
- ./logs:/app/logs
environment:
-- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
+- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
@@ -23,13 +22,22 @@ services:
depends_on:
- redis
+- db
-healthcheck:
-test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
-interval: 30s
-timeout: 10s
-retries: 3
redis:
image: redis:latest
container_name: redis
restart: always
db:
image: mysql:8.2.0
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库
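
The SQL_DSN value above is a standard go-sql-driver/mysql DSN pointing at the new db service. A minimal sketch (not part of the diff) of how such a DSN is opened with GORM's MySQL driver; the query parameters are common optional extras, not part of the compose value:

// Minimal GORM connection sketch for the compose DSN (error handling trimmed).
package main

import (
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

func main() {
	// Matches the compose file: user "newapi", password "123456",
	// host "db" (the mysql service name), database "new-api".
	dsn := "newapi:123456@tcp(db:3306)/new-api?charset=utf8mb4&parseTime=True&loc=Local"
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	_ = db
}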


@@ -1,13 +1,34 @@
package dto
-type TextToSpeechRequest struct {
-Model string `json:"model" binding:"required"`
-Input string `json:"input" binding:"required"`
-Voice string `json:"voice" binding:"required"`
-Speed float64 `json:"speed"`
-ResponseFormat string `json:"response_format"`
+type AudioRequest struct {
+Model string `json:"model"`
+Input string `json:"input"`
+Voice string `json:"voice"`
+Speed float64 `json:"speed,omitempty"`
+ResponseFormat string `json:"response_format,omitempty"`
}
type AudioResponse struct {
Text string `json:"text"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
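
WhisperVerboseJSONResponse mirrors the response_format=verbose_json payload of the transcription endpoint. A self-contained sketch (sample data made up; the struct below is a trimmed copy just for the demonstration) of decoding such a payload:

// Stand-alone decode example for a verbose_json transcription response.
package main

import (
	"encoding/json"
	"fmt"
)

type segment struct {
	Id    int     `json:"id"`
	Start float64 `json:"start"`
	End   float64 `json:"end"`
	Text  string  `json:"text"`
}

type whisperVerbose struct {
	Task     string    `json:"task,omitempty"`
	Language string    `json:"language,omitempty"`
	Duration float64   `json:"duration,omitempty"`
	Text     string    `json:"text,omitempty"`
	Segments []segment `json:"segments,omitempty"`
}

func main() {
	payload := []byte(`{"task":"transcribe","language":"english","duration":3.2,"text":"hello world","segments":[{"id":0,"start":0,"end":3.2,"text":"hello world"}]}`)
	var resp whisperVerbose
	if err := json.Unmarshal(payload, &resp); err != nil {
		panic(err)
	}
	fmt.Printf("%s (%.1fs), %d segment(s)\n", resp.Text, resp.Duration, len(resp.Segments))
}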


@@ -12,9 +12,11 @@ type ImageRequest struct {
}
type ImageResponse struct {
-Created int `json:"created"`
-Data []struct {
-Url string `json:"url"`
-B64Json string `json:"b64_json"`
-}
+Data []ImageData `json:"data"`
+Created int64 `json:"created"`
+}
+type ImageData struct {
+Url string `json:"url"`
+B64Json string `json:"b64_json"`
+RevisedPrompt string `json:"revised_prompt"`
}


@@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse


@@ -2,35 +2,37 @@ package dto
import "encoding/json" import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` BestOf int `json:"best_of,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` Echo bool `json:"echo,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
TopP float64 `json:"top_p,omitempty"` Suffix string `json:"suffix,omitempty"`
TopK int `json:"top_k,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Stop any `json:"stop,omitempty"` Temperature float64 `json:"temperature,omitempty"`
N int `json:"n,omitempty"` TopP float64 `json:"top_p,omitempty"`
Input any `json:"input,omitempty"` TopK int `json:"top_k,omitempty"`
Instruction string `json:"instruction,omitempty"` Stop any `json:"stop,omitempty"`
Size string `json:"size,omitempty"` N int `json:"n,omitempty"`
Functions any `json:"functions,omitempty"` Input any `json:"input,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` Instruction string `json:"instruction,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` Size string `json:"size,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Functions any `json:"functions,omitempty"`
Seed float64 `json:"seed,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
Tools any `json:"tools,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
User string `json:"user,omitempty"` Seed float64 `json:"seed,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` Tools []ToolCall `json:"tools,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ParallelToolCalls bool `json:"parallel_Tool_Calls,omitempty"`
} }
type OpenAITools struct { type OpenAITools struct {
@@ -103,6 +105,11 @@ func (m Message) StringContent() string {
return string(m.Content)
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -142,7 +149,7 @@ func (m Message) ParseContent() []MediaMessage {
if ok {
subObj["detail"] = detail.(string)
} else {
-subObj["detail"] = "auto"
+subObj["detail"] = "high"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
@@ -151,7 +158,16 @@ func (m Message) ParseContent() []MediaMessage {
Detail: subObj["detail"].(string), Detail: subObj["detail"].(string),
}, },
}) })
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high",
},
})
}
}
}
return contentList
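
ParseContent now accepts image_url either as the OpenAI object form with url/detail or as a bare string URL, defaulting detail to "high". A stand-alone sketch (sample messages made up) of that branching:

// Minimal illustration of the two image_url shapes and the "high" default.
package main

import (
	"encoding/json"
	"fmt"
)

func imageURL(entry map[string]any) (url, detail string) {
	detail = "high" // default applied when no detail is given
	switch v := entry["image_url"].(type) {
	case map[string]any:
		url, _ = v["url"].(string)
		if d, ok := v["detail"].(string); ok {
			detail = d
		}
	case string:
		url = v
	}
	return url, detail
}

func main() {
	raw := []byte(`[
		{"type":"image_url","image_url":{"url":"https://example.com/a.png","detail":"low"}},
		{"type":"image_url","image_url":"https://example.com/b.png"}
	]`)
	var entries []map[string]any
	if err := json.Unmarshal(raw, &entries); err != nil {
		panic(err)
	}
	for _, e := range entries {
		u, d := imageURL(e)
		fmt.Println(u, d)
	}
}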


@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
-return c.Content == nil && len(c.ToolCalls) == 0
-}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
@@ -90,9 +86,11 @@ type ToolCall struct {
}
type FunctionCall struct {
-Name string `json:"name,omitempty"`
+Description string `json:"description,omitempty"`
+Name string `json:"name,omitempty"`
// call function with arguments in JSON format
-Arguments string `json:"arguments,omitempty"`
+Parameters any `json:"parameters,omitempty"` // request
+Arguments string `json:"arguments,omitempty"`
}
type ChatCompletionsStreamResponse struct {
@@ -105,6 +103,17 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`

go.mod

@@ -4,7 +4,6 @@ module one-api
go 1.18
require (
-github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
@@ -24,6 +23,7 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
+github.com/stripe/stripe-go/v76 v76.21.0
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3
@@ -38,6 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@@ -64,7 +65,6 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
-github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect

go.sum

@@ -1,5 +1,3 @@
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
@@ -18,6 +16,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -32,6 +32,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -81,6 +83,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -124,6 +128,8 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8=
github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -131,8 +137,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -148,6 +152,8 @@ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNc
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -171,6 +177,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -196,16 +204,20 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

10
main.go
View File

@@ -3,12 +3,14 @@ package main
import ( import (
"embed" "embed"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
@@ -53,6 +55,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error()) common.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize constants
constant.InitEnv()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
if common.RedisEnabled { if common.RedisEnabled {
@@ -89,11 +93,11 @@ func main() {
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
if common.IsMasterNode { if common.IsMasterNode && constant.UpdateTask {
common.SafeGoroutine(func() { gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk() controller.UpdateMidjourneyTaskBulk()
}) })
common.SafeGoroutine(func() { gopool.Go(func() {
controller.UpdateTaskBulk() controller.UpdateTaskBulk()
}) })
} }
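This hunk swaps common.SafeGoroutine for gopool.Go from bytedance/gopkg and gates the Midjourney/task updaters behind the new constant.UpdateTask flag. A minimal sketch of the pooled-goroutine call, separate from the project's code:

package main

import (
	"fmt"
	"time"

	"github.com/bytedance/gopkg/util/gopool"
)

func main() {
	// gopool.Go runs the function on a shared, capacity-limited goroutine
	// pool instead of spawning an unbounded goroutine per call.
	gopool.Go(func() {
		fmt.Println("background task running")
	})
	time.Sleep(100 * time.Millisecond) // crude wait so the pooled task can finish in this sketch
}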

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv"
"strings" "strings"
) )
@@ -15,6 +16,8 @@ func authHelper(c *gin.Context, minRole int) {
role := session.Get("role") role := session.Get("role")
id := session.Get("id") id := session.Get("id")
status := session.Get("status") status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false
if username == nil { if username == nil {
// Check access token // Check access token
accessToken := c.Request.Header.Get("Authorization") accessToken := c.Request.Header.Get("Authorization")
@@ -33,6 +36,8 @@ func authHelper(c *gin.Context, minRole int) {
role = user.Role role = user.Role
id = user.Id id = user.Id
status = user.Status status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -42,6 +47,36 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if status.(int) == common.UserStatusDisabled { if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -50,6 +85,14 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole { if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -110,6 +153,12 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
} }
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.UserId)
}
}
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return
@@ -123,6 +172,15 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -25,6 +26,10 @@ func Distribute() func(c *gin.Context) {
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c) modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup) c.Set("group", userGroup)
if ok { if ok {
@@ -141,7 +146,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} }
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, err return nil, false, errors.New("无效的请求, " + err.Error())
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
@@ -154,18 +159,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
modelRequest.Model = "dall-e"
}
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" { relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1" modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = "whisper-1" modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
} modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
} }
c.Set("relay_mode", relayMode)
} }
return &modelRequest, shouldSelectChannel, nil return &modelRequest, shouldSelectChannel, nil
} }
@@ -175,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil { if channel == nil {
return return
} }
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type) c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization) c.Set("channel_organization", *channel.OpenAIOrganization)
} }
c.Set("auto_ban", ban) c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@@ -1,11 +1,13 @@
package middleware package middleware
import ( import (
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
) )
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -13,7 +15,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
}, },
}) })
c.Abort() c.Abort()
common.LogError(c.Request.Context(), message) common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
} }
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {

View File

@@ -205,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err return userEnabled, err
} }
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
@@ -269,6 +293,8 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} else if strings.HasPrefix(model, "g-") {
model = "g-*"
} }
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database
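Channel lookup now also maps any model starting with "g-" (presumably custom GPT ids) onto a single "g-*" wildcard entry, mirroring the existing gpt-4-gizmo handling. A standalone sketch of the normalisation:

package main

import (
	"fmt"
	"strings"
)

// Mirrors the prefix handling in CacheGetRandomSatisfiedChannel above:
// gizmo and "g-" models are looked up under wildcard channel entries.
func normalizeModel(model string) string {
	if strings.HasPrefix(model, "gpt-4-gizmo") {
		return "gpt-4-gizmo-*"
	} else if strings.HasPrefix(model, "g-") {
		return "g-*"
	}
	return model
}

func main() {
	fmt.Println(normalizeModel("g-abc123")) // g-*
	fmt.Println(normalizeModel("gpt-4o"))   // gpt-4o
}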

View File

@@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes) channel.OtherInfo = string(otherInfoBytes)
} }
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error { func (channel *Channel) Save() error {
return DB.Save(channel).Error return DB.Save(channel).Error
} }
@@ -100,8 +107,8 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
var whereClause string var whereClause string
var args []interface{} var args []interface{}
if group != "" { if group != "" {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")

View File

@@ -3,9 +3,11 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"time"
) )
type Log struct { type Log struct {
@@ -87,13 +89,13 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
if common.DataExportEnabled { if common.DataExportEnabled {
common.SafeGoroutine(func() { gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
}) })
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB tx = DB
@@ -101,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = DB.Where("type = ?", logType) tx = DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name like ?", modelName)
} }
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@@ -118,11 +120,17 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err return logs, total, err
} }
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId) tx = DB.Where("user_id = ?", userId)
@@ -130,7 +138,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = DB.Where("user_id = ? and type = ?", userId, logType) tx = DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name like ?", modelName)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
@@ -141,7 +149,14 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, total, err
for i := range logs { for i := range logs {
var otherMap map[string]interface{} var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other) otherMap = common.StrToMap(logs[i].Other)
@@ -151,7 +166,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
logs[i].Other = common.MapToJsonStr(otherMap) logs[i].Other = common.MapToJsonStr(otherMap)
} }
return logs, err return logs, total, err
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
@@ -171,12 +186,18 @@ type Stat struct {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") tx := DB.Table("logs").Select("sum(quota) quota")
// build a separate query for rpm and tpm
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
} }
if startTimestamp != 0 { if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp) tx = tx.Where("created_at >= ?", startTimestamp)
@@ -186,11 +207,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
} }
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// only count rpm and tpm over the last 60 seconds
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// run both queries
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat return stat
} }
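Quota is still summed over the full filter range, while rpm and tpm now come from a second query restricted to the last 60 seconds; both queries scan into the same Stat value. A sketch of the window cutoff and the assumed Stat shape (field names inferred from the SELECT aliases, the real struct may differ):

package main

import (
	"fmt"
	"time"
)

// Field names inferred from the "quota", "rpm" and "tpm" aliases in the diff.
type Stat struct {
	Quota int
	Rpm   int
	Tpm   int
}

func main() {
	var stat Stat
	stat.Quota = 1234             // filled by the first query (all-time, filtered)
	stat.Rpm, stat.Tpm = 42, 9000 // filled by the second query (last 60 s only)
	cutoff := time.Now().Add(-60 * time.Second).Unix()
	fmt.Printf("rpm/tpm window: created_at >= %d, stat=%+v\n", cutoff, stat)
}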

View File

@@ -31,10 +31,12 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
@@ -60,17 +62,19 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = constant.WorkerUrl common.OptionMap["OutProxyUrl"] = ""
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
common.OptionMap["PayAddress"] = "" common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["StripePriceId"] = common.StripePriceId
common.OptionMap["EpayId"] = "" common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
common.OptionMap["EpayKey"] = "" common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64) common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = "" common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerAddress"] = ""
@@ -173,6 +177,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled": case "TelegramOAuthEnabled":
@@ -181,6 +187,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled": case "EmailAliasRestrictionEnabled":
@@ -240,29 +248,33 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
constant.ServerAddress = value common.ServerAddress = value
case "WorkerUrl": case "OutProxyUrl":
constant.WorkerUrl = value common.OutProxyUrl = value
case "WorkerValidKey": case "StripeApiSecret":
constant.WorkerValidKey = value common.StripeApiSecret = value
case "PayAddress": case "StripeWebhookSecret":
constant.PayAddress = value common.StripeWebhookSecret = value
case "CustomCallbackAddress": case "StripePriceId":
constant.CustomCallbackAddress = value common.StripePriceId = value
case "EpayId": case "PaymentEnabled":
constant.EpayId = value common.PaymentEnabled, _ = strconv.ParseBool(value)
case "EpayKey": case "StripeUnitPrice":
constant.EpayKey = value common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
constant.MinTopUp, _ = strconv.Atoi(value) common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value) err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId": case "GitHubClientId":
common.GitHubClientId = value common.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer": case "Footer":
common.Footer = value common.Footer = value
case "SystemName": case "SystemName":

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant"
"strconv" "strconv"
"strings" "strings"
) )
@@ -51,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status == common.TokenStatusExhausted { if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired { } else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期") return token, errors.New("该令牌已过期")
} }
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return token, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
@@ -66,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
} }
return nil, errors.New("该令牌已过期") return token, errors.New("该令牌已过期")
} }
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
@@ -79,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
} }
return token, nil return token, nil
} }
@@ -294,7 +293,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@@ -1,13 +1,21 @@
package model package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"` Amount int `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no"` TradeNo string `json:"trade_no" gorm:"unique"`
CreateTime int64 `json:"create_time"` CreateTime int64 `json:"create_time"`
Status string `json:"status"` CompleteTime int64 `json:"complete_time"`
Status string `json:"status"`
} }
func (topUp *TopUp) Insert() error { func (topUp *TopUp) Insert() error {
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
} }
return topUp return topUp
} }
func Recharge(referenceId string, customerId string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota float64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
quota = topUp.Money * common.QuotaPerUnit
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.LogQuotaF(quota), topUp.Amount))
return nil
}
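Recharge runs inside a single transaction, takes a FOR UPDATE lock on the order row and requires it to still be pending, so a retried or concurrent callback for the same trade_no cannot be credited twice, and then credits Money × QuotaPerUnit. A tiny sketch of the credit arithmetic — the ratio shown is illustrative, not the project's default:

package main

import "fmt"

func main() {
	// Illustrative numbers only: 10 units of money at an assumed
	// QuotaPerUnit of 500000 would credit 5,000,000 quota.
	money, quotaPerUnit := 10.0, 500000.0
	fmt.Printf("credited quota: %.0f\n", money*quotaPerUnit)
}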

View File

@@ -22,6 +22,8 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -35,6 +37,7 @@ type User struct {
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度 AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度 AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
@@ -64,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int { func GetMaxUserId() int {
var user User var user User
DB.Last(&user) DB.Unscoped().Last(&user)
return user.Id return user.Id
} }
@@ -119,6 +122,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err return &user, err
} }
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
if id == 0 {
return nil, errors.New("id 为空!")
}
user := User{Id: id}
var err error = nil
if selectAll {
err = DB.Unscoped().First(&user, "id = ?", id).Error
} else {
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
}
return &user, err
}
func GetUserIdByAffCode(affCode string) (int, error) { func GetUserIdByAffCode(affCode string) (int, error) {
if affCode == "" { if affCode == "" {
return 0, errors.New("affCode 为空!") return 0, errors.New("affCode 为空!")
@@ -331,6 +348,14 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -370,6 +395,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
} }
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool { func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }
@@ -415,6 +444,18 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled, nil
} }
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
return nil return nil
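The trust-level rule added above: users with no LINUX DO binding are unaffected, while bound users must meet the configured minimum level. A standalone sketch — the minimum level of 2 is an illustrative value, not the project default:

package main

import "fmt"

func linuxDoEnabled(linuxDoId string, linuxDoLevel, minLevel int) bool {
	// Same predicate as IsLinuxDoEnabled above.
	return linuxDoId == "" || linuxDoLevel >= minLevel
}

func main() {
	fmt.Println(linuxDoEnabled("", 0, 2))      // true: no binding
	fmt.Println(linuxDoEnabled("12345", 1, 2)) // false: level too low
	fmt.Println(linuxDoEnabled("12345", 3, 2)) // true
}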

View File

@@ -2,6 +2,7 @@ package model
import ( import (
"errors" "errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"sync" "sync"
@@ -28,12 +29,12 @@ func init() {
} }
func InitBatchUpdater() { func InitBatchUpdater() {
go func() { gopool.Go(func() {
for { for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate() batchUpdate()
} }
}() })
} }
func addNewRecord(type_ int, id int, value int) { func addNewRecord(type_ int, id int, value int) {

View File

@@ -10,12 +10,13 @@ import (
type Adaptor interface { type Adaptor interface {
// Init IsStream bool // Init IsStream bool
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) Init(info *relaycommon.RelayInfo)
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string GetModelList() []string

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
) )
@@ -15,17 +16,18 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) var fullRequestURL string
if info.RelayMode == constant.RelayModeEmbeddings { switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
} }
return fullRequestURL, nil return fullRequestURL, nil
} }
@@ -42,22 +44,32 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
baiduRequest := requestOpenAI2Ali(*request) aliReq := requestOpenAI2Ali(*request)
return baiduRequest, nil return aliReq, nil
} }
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request)
return aliRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -65,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { switch info.RelayMode {
err, usage = aliStreamHandler(c, resp) case constant.RelayModeImagesGenerations:
} else { err, usage = aliImageHandler(c, resp, info)
switch info.RelayMode { case constant.RelayModeEmbeddings:
case constant.RelayModeEmbeddings: err, usage = aliEmbeddingHandler(c, resp)
err, usage = aliEmbeddingHandler(c, resp) default:
default: if info.IsStream {
err, usage = aliHandler(c, resp) err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
} }
return return

View File

@@ -60,13 +60,40 @@ type AliUsage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
type AliOutput struct { type TaskResult struct {
Text string `json:"text"` B64Image string `json:"b64_image,omitempty"`
FinishReason string `json:"finish_reason"` Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} }
type AliChatResponse struct { type AliOutput struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
}
type AliResponse struct {
Output AliOutput `json:"output"` Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"` Usage AliUsage `json:"usage"`
AliError AliError
} }
type AliImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}

177
relay/channel/ali/image.go Normal file
View File

@@ -0,0 +1,177 @@
package ali
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
step++
rsp, err, body := updateTask(info, taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := common.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
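asyncTaskWait above polls the task endpoint every 3 seconds for at most 20 attempts, so an image task that is still running times out after roughly a minute. The worst-case wait, spelled out:

package main

import "fmt"

func main() {
	// Constants taken from asyncTaskWait above: 20 polls, 3 s apart.
	waitSeconds, maxStep := 3, 20
	fmt.Printf("worst-case wait before \"aliAsyncTaskWait timeout\": ~%d s\n", waitSeconds*maxStep)
}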

View File

@@ -16,34 +16,13 @@ import (
const EnableSearchModelSuffix = "-internet" const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]AliMessage, 0, len(request.Messages)) if request.TopP >= 1 {
//prompt := "" request.TopP = 0.999
for i := 0; i < len(request.Messages); i++ { } else if request.TopP <= 0 {
message := request.Messages[i] request.TopP = 0.001
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
} }
return &request
} }
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
@@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text) content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
@@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
return &fullTextResponse return &fullTextResponse
} }
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text) choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" { if aliResponse.Output.FinishReason != "null" {
@@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage var usage dto.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {
@@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
var aliResponse AliChatResponse var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse) err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
@@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
} }
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliChatResponse var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
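The rewritten requestOpenAI2Ali now forwards the OpenAI-shaped request mostly unchanged, but nudges top_p just inside the open interval (0, 1). A standalone sketch of that clamp; the reason (the upstream API rejecting exact 0 or 1) is an assumption, the clamp values themselves come from the diff:

package main

import "fmt"

func clampTopP(topP float64) float64 {
	// Same bounds as the hunk above.
	if topP >= 1 {
		return 0.999
	}
	if topP <= 0 {
		return 0.001
	}
	return topP
}

func main() {
	fmt.Println(clampTopP(1.0), clampTopP(0.0), clampTopP(0.7)) // 0.999 0.001 0.7
}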

View File

@@ -7,14 +7,19 @@ import (
 	"io"
 	"net/http"
 	"one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 )
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	if info.IsStream && c.Request.Header.Get("Accept") == "" {
-		req.Header.Set("Accept", "text/event-stream")
+	if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+		// multipart/form-data
+	} else {
+		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		if info.IsStream && c.Request.Header.Get("Accept") == "" {
+			req.Header.Set("Accept", "text/event-stream")
+		}
 	}
 }
@@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
 	return resp, nil
 }
+func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.GetRequestURL(info)
+	if err != nil {
+		return nil, fmt.Errorf("get request url failed: %w", err)
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	// set form data
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	err = a.SetupRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}
 func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	resp, err := service.GetHttpClient().Do(req)
 	if err != nil {
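The new DoFormRequest above differs from DoApiRequest mainly in that it copies the caller's Content-Type header verbatim, since a multipart body can only be parsed with its original boundary. A rough stdlib-only sketch of that pass-through idea; the `forward` helper and the httptest upstream are illustrative names, not part of the repo:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
)

// forward re-sends a multipart body to an upstream URL, copying the original
// Content-Type verbatim because it carries the multipart boundary.
func forward(method, contentType, upstreamURL string, body io.Reader) (*http.Response, error) {
	req, err := http.NewRequest(method, upstreamURL, body)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}
	req.Header.Set("Content-Type", contentType)
	return http.DefaultClient.Do(req)
}

func main() {
	// Fake upstream that just echoes the Content-Type it received.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "got content-type:", r.Header.Get("Content-Type"))
	}))
	defer upstream.Close()

	// Build a small multipart form, as a client uploading an audio file would.
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)
	w.WriteField("model", "whisper-1")
	part, _ := w.CreateFormFile("file", "audio.mp3")
	part.Write([]byte("fake audio bytes"))
	w.Close()

	resp, err := forward(http.MethodPost, w.FormDataContentType(), upstream.URL, &buf)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Print(string(out))
}
```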


@@ -20,12 +20,17 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") { if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage a.RequestMode = RequestModeMessage
} else { } else {
@@ -41,7 +46,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }


@@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
} }
service.Done(c) service.Done(c)
err = resp.Body.Close() if resp != nil {
if err != nil { err = resp.Body.Close()
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
} }
return nil, &usage return nil, &usage
} }
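The AWS hunk above only closes the response body when resp is non-nil, which avoids a nil-pointer panic on code paths that never produced a plain HTTP response. A tiny sketch of the guard; `closeBody` is a hypothetical helper, not a function in the repo:

```go
package main

import (
	"fmt"
	"net/http"
)

// closeBody closes an HTTP response body only when one actually exists;
// resp may legitimately be nil on paths that never issued a plain HTTP call.
func closeBody(resp *http.Response) error {
	if resp == nil || resp.Body == nil {
		return nil
	}
	return resp.Body.Close()
}

func main() {
	var resp *http.Response // nil: simulate a path with no HTTP response
	fmt.Println("close error:", closeBody(resp))
}
```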


@@ -16,12 +16,17 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
@@ -99,11 +104,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil


@@ -21,12 +21,17 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") { if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage a.RequestMode = RequestModeMessage
} else { } else {
@@ -53,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }


@@ -5,11 +5,18 @@ type ClaudeMetadata struct {
} }
type ClaudeMediaMessage struct { type ClaudeMediaMessage struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"` Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
} }
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
@@ -23,6 +30,18 @@ type ClaudeMessage struct {
Content any `json:"content"` Content any `json:"content"`
} }
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@@ -35,7 +54,9 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
} }
type ClaudeError struct { type ClaudeError struct {
@@ -44,24 +65,20 @@ type ClaudeError struct {
} }
type ClaudeResponse struct { type ClaudeResponse struct {
Id string `json:"id"` Id string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"` Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"` Completion string `json:"completion"`
StopReason string `json:"stop_reason"` StopReason string `json:"stop_reason"`
Model string `json:"model"` Model string `json:"model"`
Error ClaudeError `json:"error"` Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"` Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only ContentBlock *ClaudeMediaMessage `json:"content_block"`
Message *ClaudeResponse `json:"message"` // stream only: message_start Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
} }
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
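The new Tool and InputSchema structs above mirror Anthropic's tools format; the relay code in the next file copies an OpenAI function definition's JSON-Schema `parameters` into `input_schema`. A self-contained sketch of that mapping using simplified stand-in structs (the real ones live in the repo's dto and claude packages):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the repo's types.
type OpenAIFunction struct {
	Name        string         `json:"name"`
	Description string         `json:"description,omitempty"`
	Parameters  map[string]any `json:"parameters,omitempty"`
}

type InputSchema struct {
	Type       string `json:"type"`
	Properties any    `json:"properties,omitempty"`
	Required   any    `json:"required,omitempty"`
}

type ClaudeTool struct {
	Name        string      `json:"name"`
	Description string      `json:"description,omitempty"`
	InputSchema InputSchema `json:"input_schema"`
}

// toClaudeTool copies the JSON-Schema pieces of an OpenAI function definition
// into Anthropic's input_schema layout.
func toClaudeTool(f OpenAIFunction) ClaudeTool {
	t, _ := f.Parameters["type"].(string)
	return ClaudeTool{
		Name:        f.Name,
		Description: f.Description,
		InputSchema: InputSchema{
			Type:       t,
			Properties: f.Parameters["properties"],
			Required:   f.Parameters["required"],
		},
	}
}

func main() {
	f := OpenAIFunction{
		Name:        "get_weather",
		Description: "Look up the weather for a city",
		Parameters: map[string]any{
			"type": "object",
			"properties": map[string]any{
				"city": map[string]any{"type": "string"},
			},
			"required": []string{"city"},
		},
	}
	out, _ := json.MarshalIndent(toClaudeTool(f), "", "  ")
	fmt.Println(string(out))
}
```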


@@ -8,12 +8,10 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@@ -30,6 +28,7 @@ func stopReasonClaude2OpenAI(reason string) string {
} }
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", Prompt: "",
@@ -60,6 +59,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
} }
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
@@ -68,6 +83,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.TopK, TopK: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools,
} }
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
@@ -154,11 +170,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url // 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url) mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data claudeMediaMessage.Source.Data = data
} else { } else {
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) _, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -184,6 +200,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@@ -199,10 +216,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString("") choice.Delta.SetContentString("")
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
return nil, nil if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil, nil
}
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index if claudeResponse.Delta != nil {
choice.Delta.SetContentString(claudeResponse.Delta.Text) choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
}
} else if claudeResponse.Type == "message_delta" { } else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" { if finishReason != "null" {
@@ -218,6 +258,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeUsage == nil { if claudeUsage == nil {
claudeUsage = &ClaudeUsage{} claudeUsage = &ClaudeUsage{}
} }
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
return &response, claudeUsage return &response, claudeUsage
@@ -230,6 +274,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
} }
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
@@ -244,20 +293,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
choices = append(choices, choice) choices = append(choices, choice)
} else { } else {
fullTextResponse.Id = claudeResponse.Id fullTextResponse.Id = claudeResponse.Id
for i, message := range claudeResponse.Content { for _, message := range claudeResponse.Content {
content, _ := json.Marshal(message.Text) if message.Type == "tool_use" {
choice := dto.OpenAITextResponseChoice{ args, _ := json.Marshal(message.Input)
Index: i, tools = append(tools, dto.ToolCall{
Message: dto.Message{ ID: message.Id,
Role: "assistant", Type: "function", // compatible with other OpenAI derivative applications
Content: content, Function: dto.FunctionCall{
}, Name: message.Name,
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), Arguments: string(args),
},
})
} }
choices = append(choices, choice)
} }
} }
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choice.SetStringContent(responseText)
if len(tools) > 0 {
choice.Message.ToolCalls = tools
}
choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse
} }
@@ -269,89 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
responseText := "" responseText := ""
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 { service.SetEventStreamHeaders(c)
return 0, nil, nil
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
} }
if i := strings.Index(string(data), "\n"); i >= 0 { data = strings.TrimPrefix(data, "data:")
return i + 1, data[0:i], nil data = strings.TrimSpace(data)
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
} }
if atEOF {
return len(data), data, nil response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
} }
return 0, nil, nil if requestMode == RequestModeCompletion {
}) responseText += claudeResponse.Completion
dataChan := make(chan string, 5) responseId = response.Id
stopChan := make(chan bool, 2) } else {
go func() { if claudeResponse.Type == "message_start" {
for scanner.Scan() { // message_start, 获取usage
data := scanner.Text() responseId = claudeResponse.Message.Id
if !strings.HasPrefix(data, "data: ") { info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
continue continue
} }
data = strings.TrimPrefix(data, "data: ")
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
} }
stopChan <- true //response.Id = responseId
}() response.Id = responseId
isFirst := true response.Created = createdTime
service.SetEventStreamHeaders(c) response.Model = info.UpstreamModelName
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) err = service.ObjectData(c, response)
if response == nil { if err != nil {
return true common.LogError(c, "send_stream_response_failed: "+err.Error())
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
return true
case <-stopChan:
return false
} }
}) }
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
@@ -370,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
} }
service.Done(c) service.Done(c)
err := resp.Body.Close() resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage return nil, usage
} }
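In the streaming hunks above, a `content_block_start` event of type `tool_use` opens a tool call (id and name, empty arguments) and later `input_json_delta` events carry argument fragments, which the relay forwards as OpenAI-style tool-call deltas. The sketch below simply accumulates the fragments to show the shape of the data; the event names follow Anthropic's streaming protocol, while the structs are simplified stand-ins rather than the repo's dto types:

```go
package main

import "fmt"

type toolCall struct {
	ID, Name, Arguments string
}

type claudeEvent struct {
	Type        string // "content_block_start", "content_block_delta", ...
	BlockType   string // "tool_use" on a block start
	ID, Name    string
	PartialJSON string // argument fragment on "input_json_delta"
}

// accumulate folds Claude stream events into OpenAI-style tool calls.
func accumulate(events []claudeEvent) []toolCall {
	var calls []toolCall
	for _, ev := range events {
		switch ev.Type {
		case "content_block_start":
			if ev.BlockType == "tool_use" {
				calls = append(calls, toolCall{ID: ev.ID, Name: ev.Name})
			}
		case "content_block_delta":
			if len(calls) > 0 && ev.PartialJSON != "" {
				calls[len(calls)-1].Arguments += ev.PartialJSON
			}
		}
	}
	return calls
}

func main() {
	events := []claudeEvent{
		{Type: "content_block_start", BlockType: "tool_use", ID: "toolu_1", Name: "get_weather"},
		{Type: "content_block_delta", PartialJSON: `{"city":`},
		{Type: "content_block_delta", PartialJSON: `"Paris"}`},
	}
	fmt.Printf("%+v\n", accumulate(events))
}
```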


@@ -1,6 +1,7 @@
package cloudflare package cloudflare
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -15,10 +16,7 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -38,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch info.RelayMode {
case constant.RelayModeCompletions: case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil return convertCf2CompletionsRequest(*request), nil
default: default:
@@ -58,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
return request, nil return request, nil
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段
file, _, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
// 打开临时文件用于保存上传的文件内容
requestBody := &bytes.Buffer{}
// 将上传的文件内容复制到临时文件
if _, err := io.Copy(requestBody, file); err != nil {
return nil, err
}
return requestBody, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { switch info.RelayMode {
err, usage = cfStreamHandler(c, resp, info) case constant.RelayModeEmbeddings:
} else { fallthrough
err, usage = cfHandler(c, resp, info) case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
} else {
err, usage = cfHandler(c, resp, info)
}
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = cfSTTHandler(c, resp, info)
} }
return return
} }
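ConvertAudioRequest above pulls the uploaded `file` part out of the incoming multipart form and forwards its raw bytes as the upstream request body. A self-contained sketch of that extraction; the fake incoming request in main exists only to make the example runnable:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
)

// audioBodyFromForm pulls the uploaded "file" part out of a multipart request
// and returns its raw bytes, which the adaptor forwards as the upstream body.
func audioBodyFromForm(r *http.Request) (io.Reader, error) {
	file, _, err := r.FormFile("file")
	if err != nil {
		return nil, fmt.Errorf("file is required: %w", err)
	}
	defer file.Close()

	buf := &bytes.Buffer{}
	if _, err := io.Copy(buf, file); err != nil {
		return nil, err
	}
	return buf, nil
}

func main() {
	// Build a fake incoming request the way a client upload would look.
	var form bytes.Buffer
	w := multipart.NewWriter(&form)
	part, _ := w.CreateFormFile("file", "clip.wav")
	part.Write([]byte("fake wav bytes"))
	w.Close()

	req := httptest.NewRequest(http.MethodPost, "/v1/audio/transcriptions", &form)
	req.Header.Set("Content-Type", w.FormDataContentType())

	body, err := audioBodyFromForm(req)
	if err != nil {
		panic(err)
	}
	data, _ := io.ReadAll(body)
	fmt.Printf("forwarding %d audio bytes upstream\n", len(data))
}
```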


@@ -1,6 +1,7 @@
package cloudflare package cloudflare
var ModelList = []string{ var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16", "@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8", "@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1", "@cf/mistral/mistral-7b-instruct-v0.1",


@@ -11,3 +11,11 @@ type CfRequest struct {
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
} }
type CfAudioResponse struct {
Result CfSTTResult `json:"result"`
}
type CfSTTResult struct {
Text string `json:"text"`
}


@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
) )
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
@@ -30,6 +31,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
id := service.GetResponseID(c) id := service.GetResponseID(c)
var responseText string var responseText string
isFirst := true
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
@@ -56,6 +58,10 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
response.Id = id response.Id = id
response.Model = info.UpstreamModelName response.Model = info.UpstreamModelName
err = service.ObjectData(c, response) err = service.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil { if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error()) common.LogError(c, "error_rendering_stream_response: "+err.Error())
} }
@@ -113,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
_, _ = c.Writer.Write(jsonResponse) _, _ = c.Writer.Write(jsonResponse)
return nil, usage return nil, usage
} }
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var cfResp CfAudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
audioResp := &dto.AudioResponse{
Text: cfResp.Result.Text,
}
jsonResponse, err := json.Marshal(audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}


@@ -1,6 +1,7 @@
package cohere package cohere
import ( import (
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@@ -14,10 +15,17 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -34,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil return requestOpenAI2Cohere(*request), nil
} }


@@ -14,12 +14,17 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me //TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -32,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }


@@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" { } else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" { } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
choice.Delta.SetContentString(difyResponse.Answer) choice.Delta.SetContentString(difyResponse.Answer)
} }
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
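The one-line change above treats `agent_message` chunks the same as `message` chunks, so answers from Dify agent apps are streamed through as content. A small sketch of that event filter (a hypothetical helper with a simplified signature):

```go
package main

import "fmt"

// contentFromEvent mirrors the change above: both "message" and
// "agent_message" chunks carry answer text that should be forwarded.
func contentFromEvent(event, answer string) (string, bool) {
	switch event {
	case "message", "agent_message":
		return answer, true
	default:
		return "", false
	}
}

func main() {
	for _, ev := range []string{"message", "agent_message", "node_started"} {
		if text, ok := contentFromEvent(ev, "hi"); ok {
			fmt.Println(ev, "->", text)
		} else {
			fmt.Println(ev, "-> skipped")
		}
	}
}
```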


@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
@@ -14,22 +15,23 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
// 定义一个映射,存储模型名称和对应的版本 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName] version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta { if !beta {
if info.ApiVersion != "" { if info.ApiVersion != "" {
version = info.ApiVersion version = info.ApiVersion
@@ -40,7 +42,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
action := "generateContent" action := "generateContent"
if info.IsStream { if info.IsStream {
action = "streamGenerateContent" action = "streamGenerateContent?alt=sse"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
@@ -51,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
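The Gemini adaptor now reads the per-model API version from a package-level map (constant.GeminiModelMap in the diff) and, when streaming, requests `streamGenerateContent?alt=sse` so the upstream answers as SSE. A sketch of that URL construction; the map contents and the "v1" fallback here are just the examples visible in the diff:

```go
package main

import "fmt"

// Model-to-API-version map; in the repo this lives in constant.GeminiModelMap.
var geminiModelMap = map[string]string{
	"gemini-1.5-pro-latest":   "v1beta",
	"gemini-1.5-flash-latest": "v1beta",
	"gemini-ultra":            "v1beta",
}

// requestURL picks the per-model version (falling back to a default) and uses
// streamGenerateContent?alt=sse for streaming so the reply arrives as SSE.
func requestURL(baseURL, model, defaultVersion string, stream bool) string {
	version, ok := geminiModelMap[model]
	if !ok {
		version = defaultVersion
	}
	action := "generateContent"
	if stream {
		action = "streamGenerateContent?alt=sse"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, model, action)
}

func main() {
	fmt.Println(requestURL("https://generativelanguage.googleapis.com", "gemini-1.5-pro-latest", "v1", true))
	fmt.Println(requestURL("https://generativelanguage.googleapis.com", "gemini-pro", "v1", false))
}
```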


@@ -12,9 +12,15 @@ type GeminiInlineData struct {
Data string `json:"data"` Data string `json:"data"`
} }
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type GeminiPart struct { type GeminiPart struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
} }
type GeminiChatContent struct { type GeminiChatContent struct {


@@ -4,18 +4,14 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
"github.com/gin-gonic/gin"
) )
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.MaxTokens,
}, },
} }
if textRequest.Functions != nil { if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
functions = append(functions, tool.Function)
}
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{ geminiRequest.Tools = []GeminiChatTools{
{ {
FunctionDeclarations: textRequest.Functions, FunctionDeclarations: textRequest.Functions,
@@ -77,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
if imageNum > GeminiVisionMaxImageNum { if imageNum > GeminiVisionMaxImageNum {
continue continue
} }
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) // 判断是否是url
parts = append(parts, GeminiPart{ if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
InlineData: &GeminiInlineData{ // 是url获取图片的类型和base64编码的数据
MimeType: mimeType, mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
Data: data, parts = append(parts, GeminiPart{
}, InlineData: &GeminiInlineData{
}) MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
continue
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
} }
} }
content.Parts = parts content.Parts = parts
@@ -126,6 +147,30 @@ func (g *GeminiChatResponse) GetResponseText() string {
return "" return ""
} }
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
var toolCalls []dto.ToolCall
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
//common.SysError("getToolCalls failed: " + err.Error())
return toolCalls
}
toolCall := dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -144,8 +189,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: relaycommon.StopFinishReason, FinishReason: relaycommon.StopFinishReason,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
content, _ = json.Marshal(candidate.Content.Parts[0].Text) if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.Content = content choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
}
} }
fullTextResponse.Choices = append(fullTextResponse.Choices, choice) fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
} }
@@ -154,8 +202,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(geminiResponse.GetResponseText()) //choice.Delta.SetContentString(geminiResponse.GetResponseText())
choice.FinishReason = &relaycommon.StopFinishReason if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirst := geminiResponse.Candidates[0].Content.Parts[0]
if respFirst.FunctionCall != nil {
// function response
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
} else {
// text response
choice.Delta.SetContentString(respFirst.Text)
}
}
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = "gemini" response.Model = "gemini"
@@ -165,104 +222,60 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := "" responseText := ""
responseJson := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
var usage = &dto.Usage{} var usage = &dto.Usage{}
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(bufio.ScanLines)
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
responseJson += data
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
// send data timeout, stop the stream
common.LogError(c, "send data timeout, stop the stream")
break
}
}
stopChan <- true
}()
isFirst := true
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { for scanner.Scan() {
select { data := scanner.Text()
case data := <-dataChan: info.SetFirstResponseTime()
if isFirst { data = strings.TrimSpace(data)
isFirst = false if !strings.HasPrefix(data, "data: ") {
info.FirstResponseTime = time.Now() continue
}
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: info.UpstreamModelName,
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
} }
}) data = strings.TrimPrefix(data, "data: ")
var geminiChatResponses []GeminiChatResponse data = strings.TrimSuffix(data, "\"")
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses) var geminiResponse GeminiChatResponse
if err != nil { err := json.Unmarshal([]byte(data), &geminiResponse)
log.Printf("cannot get gemini usage: %s", err.Error()) if err != nil {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) common.LogError(c, "error unmarshalling stream response: "+err.Error())
} else { continue
for _, response := range geminiChatResponses { }
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response.Id = id
response.Created = createAt
responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
}
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
} }
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason)
service.ObjectData(c, response)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := service.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())
} }
} }
service.Done(c) service.Done(c)
err = resp.Body.Close() resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
}
return nil, usage return nil, usage
} }
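With `alt=sse` the rewritten handler no longer re-assembles a JSON array from the raw body; it scans `data: ` lines, unmarshals each chunk, appends the candidate text, and takes token counts from usageMetadata when present. A self-contained sketch of that loop with trimmed-down response structs (field names follow Gemini's JSON, but the types are not the repo's):

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed-down stand-in for the repo's Gemini response type.
type geminiChunk struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
	} `json:"candidates"`
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
}

func main() {
	// With ?alt=sse the upstream emits one JSON object per "data: " line.
	stream := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hel\"}]}}]}\n" +
		"data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"lo\"}]}}],\"usageMetadata\":{\"promptTokenCount\":3,\"candidatesTokenCount\":2,\"totalTokenCount\":5}}\n"

	scanner := bufio.NewScanner(strings.NewReader(stream))
	scanner.Split(bufio.ScanLines)

	var text string
	var promptTokens, completionTokens int
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		var chunk geminiChunk
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk); err != nil {
			continue // skip malformed chunks, as the handler does
		}
		if len(chunk.Candidates) > 0 && len(chunk.Candidates[0].Content.Parts) > 0 {
			text += chunk.Candidates[0].Content.Parts[0].Text
		}
		if chunk.UsageMetadata.TotalTokenCount != 0 {
			promptTokens = chunk.UsageMetadata.PromptTokenCount
			completionTokens = chunk.UsageMetadata.CandidatesTokenCount
		}
	}
	fmt.Printf("text=%q prompt=%d completion=%d\n", text, promptTokens, completionTokens)
}
```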


@@ -15,10 +15,17 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -36,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil return request, nil
} }


@@ -10,16 +10,22 @@ import (
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -36,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch relayMode { switch info.RelayMode {
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil return requestOpenAI2Embeddings(*request), nil
default: default:
@@ -58,11 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string err, usage = openai.OaiStreamHandler(c, resp, info)
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
} else { } else {
if info.RelayMode == relayconstant.RelayModeEmbeddings { if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)


@@ -3,14 +3,18 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"` Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
} }
type OllamaEmbeddingRequest struct { type OllamaEmbeddingRequest struct {
@@ -21,6 +25,3 @@ type OllamaEmbeddingRequest struct {
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Embedding []float64 `json:"embedding,omitempty"` Embedding []float64 `json:"embedding,omitempty"`
} }
//type OllamaOptions struct {
//}


@@ -28,14 +28,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
Stop, _ = request.Stop.([]string) Stop, _ = request.Stop.([]string)
} }
return &OllamaRequest{ return &OllamaRequest{
Model: request.Model, Model: request.Model,
Messages: messages, Messages: messages,
Stream: request.Stream, Stream: request.Stream,
Temperature: request.Temperature, Temperature: request.Temperature,
Seed: request.Seed, Seed: request.Seed,
Topp: request.TopP, Topp: request.TopP,
TopK: request.TopK, TopK: request.TopK,
Stop: Stop, Stop: Stop,
Tools: request.Tools,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
} }
} }
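The Ollama request struct and the mapping above now pass tools, response_format, and the penalty settings straight through from the OpenAI-style request. A small sketch of what such a pass-through request looks like once serialized; the structs are simplified stand-ins, not the repo's dto types:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the repo's types.
type toolCall struct {
	Type     string `json:"type"`
	Function any    `json:"function"`
}

type ollamaRequest struct {
	Model            string     `json:"model,omitempty"`
	Stream           bool       `json:"stream,omitempty"`
	Tools            []toolCall `json:"tools,omitempty"`
	ResponseFormat   any        `json:"response_format,omitempty"`
	FrequencyPenalty float64    `json:"frequency_penalty,omitempty"`
	PresencePenalty  float64    `json:"presence_penalty,omitempty"`
}

func main() {
	// The new fields are copied directly from the OpenAI-style request, so a
	// tools-bearing chat request keeps its tool definitions on the way to Ollama.
	req := ollamaRequest{
		Model:  "llama3.1",
		Stream: true,
		Tools: []toolCall{{
			Type:     "function",
			Function: map[string]any{"name": "get_weather", "parameters": map[string]any{"type": "object"}},
		}},
	}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
}
```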


@@ -1,10 +1,13 @@
package openai package openai
import ( import (
"bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
@@ -14,22 +17,16 @@ import (
"one-api/relay/channel/minimax" "one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/relay/constant"
"strings" "strings"
) )
type Adaptor struct { type Adaptor struct {
ChannelType int ChannelType int
ResponseFormat string
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
return nil, nil
}
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
} }
@@ -74,28 +71,84 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
}
return bytes.NewReader(jsonData), nil
} else {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 添加文件字段
file, header, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
part, err := writer.CreateFormFile("file", header.Filename)
if err != nil {
return nil, errors.New("create form file failed")
}
if _, err := io.Copy(part, file); err != nil {
return nil, errors.New("copy file failed")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &requestBody, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody) if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { switch info.RelayMode {
var responseText string case constant.RelayModeAudioSpeech:
var toolCount int err, usage = OpenaiTTSHandler(c, resp, info)
err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info) case constant.RelayModeAudioTranslation:
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { fallthrough
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) case constant.RelayModeAudioTranscription:
usage.CompletionTokens += toolCount * 7 err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }
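For transcription and translation, ConvertAudioRequest above rebuilds a multipart body containing the `model` field and the uploaded file, and DoRequest then sends it through the form-request path; for speech the JSON body is kept. A self-contained sketch of the multipart construction (the helper name and the main function are illustrative, not from the repo):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
)

// buildTranscriptionBody re-packs an uploaded audio file plus the model name
// into a fresh multipart body, returning the body and its Content-Type
// (which carries the new boundary) for the upstream request.
func buildTranscriptionBody(model, filename string, file io.Reader) (io.Reader, string, error) {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)
	if err := w.WriteField("model", model); err != nil {
		return nil, "", err
	}
	part, err := w.CreateFormFile("file", filename)
	if err != nil {
		return nil, "", fmt.Errorf("create form file failed: %w", err)
	}
	if _, err := io.Copy(part, file); err != nil {
		return nil, "", fmt.Errorf("copy file failed: %w", err)
	}
	// Closing the writer writes the terminating boundary.
	if err := w.Close(); err != nil {
		return nil, "", err
	}
	return &body, w.FormDataContentType(), nil
}

func main() {
	body, contentType, err := buildTranscriptionBody("whisper-1", "clip.mp3", bytes.NewReader([]byte("fake mp3")))
	if err != nil {
		panic(err)
	}
	data, _ := io.ReadAll(body)
	fmt.Println("content-type:", contentType)
	fmt.Println("body bytes:", len(data))
}
```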


@@ -1,20 +1,21 @@
package openai

var ModelList = []string{
    "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-instruct",
    "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
    "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
    "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
    "gpt-4-vision-preview",
    "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
    "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
    "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
    "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
    "text-moderation-latest", "text-moderation-stable",
    "text-davinci-edit-001",
    "davinci-002", "babbage-002",
    "dall-e-2", "dall-e-3",
    "whisper-1",
    "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}

View File

@@ -4,6 +4,8 @@ import (
    "bufio"
    "bytes"
    "encoding/json"
    "fmt"
    "github.com/bytedance/gopkg/util/gopool"
    "github.com/gin-gonic/gin"
    "io"
    "net/http"
@@ -18,34 +20,36 @@ import (
    "time"
)

func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
    containStreamUsage := false
    var responseId string
    var createAt int64 = 0
    var systemFingerprint string
    model := info.UpstreamModelName

    var responseTextBuilder strings.Builder
    var usage = &dto.Usage{}
    var streamItems []string // store stream items
    toolCount := 0

    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(bufio.ScanLines)

    service.SetEventStreamHeaders(c)

    ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
    defer ticker.Stop()

    stopChan := make(chan bool)
    defer close(stopChan)

    var (
        lastStreamData string
        mu             sync.Mutex
    )
    gopool.Go(func() {
        for scanner.Scan() {
            info.SetFirstResponseTime()
            ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
            data := scanner.Text()
            if len(data) < 6 { // ignore blank line or wrong format
                continue
            }
            if data[:6] != "data: " && data[:6] != "[DONE]" {
                continue
            }
            mu.Lock()
            data = data[6:]
            if !strings.HasPrefix(data, "[DONE]") {
                if lastStreamData != "" {
                    err := service.StringData(c, lastStreamData)
                    if err != nil {
                        common.LogError(c, "streaming error: "+err.Error())
                    }
                }
                lastStreamData = data
                streamItems = append(streamItems, data)
            }
            mu.Unlock()
        }
        common.SafeSendBool(stopChan, true)
    })

    select {
    case <-ticker.C:
        // upstream kept us waiting longer than StreamingTimeout
        common.LogError(c, "streaming timeout")
    case <-stopChan:
        // stream finished normally
    }

    shouldSendLastResp := true
    var lastStreamResponse dto.ChatCompletionsStreamResponse
    err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
    if err == nil {
        responseId = lastStreamResponse.Id
        createAt = lastStreamResponse.Created
        systemFingerprint = lastStreamResponse.GetSystemFingerprint()
        model = lastStreamResponse.Model
        if service.ValidUsage(lastStreamResponse.Usage) {
            containStreamUsage = true
            usage = lastStreamResponse.Usage
            if !info.ShouldIncludeUsage {
                shouldSendLastResp = false
            }
        }
    }
    if shouldSendLastResp {
        service.StringData(c, lastStreamData)
    }

    // count tokens from the buffered stream items
    streamResp := "[" + strings.Join(streamItems, ",") + "]"
    switch info.RelayMode {
    case relayconstant.RelayModeChatCompletions:
        var streamResponses []dto.ChatCompletionsStreamResponse
        err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
        if err != nil {
            // parsing the whole array failed, fall back to parsing item by item
            common.SysError("error unmarshalling stream response: " + err.Error())
            for _, item := range streamItems {
                var streamResponse dto.ChatCompletionsStreamResponse
                err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
                if err == nil {
                    //if service.ValidUsage(streamResponse.Usage) {
                    // usage = streamResponse.Usage
                    //}
                    for _, choice := range streamResponse.Choices {
                        responseTextBuilder.WriteString(choice.Delta.GetContentString())
                        if choice.Delta.ToolCalls != nil {
                            if len(choice.Delta.ToolCalls) > toolCount {
                                toolCount = len(choice.Delta.ToolCalls)
                            }
                            for _, tool := range choice.Delta.ToolCalls {
                                responseTextBuilder.WriteString(tool.Function.Name)
                                responseTextBuilder.WriteString(tool.Function.Arguments)
                            }
                        }
                    }
                }
            }
        } else {
            for _, streamResponse := range streamResponses {
                //if service.ValidUsage(streamResponse.Usage) {
                // usage = streamResponse.Usage
                // containStreamUsage = true
                //}
                for _, choice := range streamResponse.Choices {
                    responseTextBuilder.WriteString(choice.Delta.GetContentString())
                    if choice.Delta.ToolCalls != nil {
                        if len(choice.Delta.ToolCalls) > toolCount {
                            toolCount = len(choice.Delta.ToolCalls)
                        }
                        for _, tool := range choice.Delta.ToolCalls {
                            responseTextBuilder.WriteString(tool.Function.Name)
                            responseTextBuilder.WriteString(tool.Function.Arguments)
                        }
                    }
                }
            }
        }
    case relayconstant.RelayModeCompletions:
        var streamResponses []dto.CompletionsStreamResponse
        err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
        if err != nil {
            // parsing the whole array failed, fall back to parsing item by item
            common.SysError("error unmarshalling stream response: " + err.Error())
            for _, item := range streamItems {
                var streamResponse dto.CompletionsStreamResponse
                err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
                if err == nil {
                    for _, choice := range streamResponse.Choices {
                        responseTextBuilder.WriteString(choice.Text)
                    }
                }
            }
        } else {
            for _, streamResponse := range streamResponses {
                for _, choice := range streamResponse.Choices {
                    responseTextBuilder.WriteString(choice.Text)
                }
            }
        }
    }

    if !containStreamUsage {
        usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
        usage.CompletionTokens += toolCount * 7
    }

    if info.ShouldIncludeUsage && !containStreamUsage {
        response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
        response.SetSystemFingerprint(systemFingerprint)
        service.ObjectData(c, response)
    }
    service.Done(c)

    resp.Body.Close()
    return nil, usage
}
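As a side note, the streaming loop above boils down to ordinary line-based SSE parsing: scan the body line by line, keep only lines starting with "data: ", and stop at "[DONE]". A stripped-down, self-contained sketch of that pattern follows; the sample payload and the local delta struct are invented for illustration and are not the project's dto types.

package main

import (
    "bufio"
    "encoding/json"
    "fmt"
    "strings"
)

type delta struct {
    Choices []struct {
        Delta struct {
            Content string `json:"content"`
        } `json:"delta"`
    } `json:"choices"`
}

func main() {
    // Simulated upstream SSE body.
    body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
        "data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n" +
        "data: [DONE]\n"

    var text strings.Builder
    scanner := bufio.NewScanner(strings.NewReader(body))
    scanner.Split(bufio.ScanLines)
    for scanner.Scan() {
        line := scanner.Text()
        if len(line) < 6 || !strings.HasPrefix(line, "data: ") {
            continue // ignore blank lines and non-data fields
        }
        payload := line[6:]
        if strings.HasPrefix(payload, "[DONE]") {
            break
        }
        var d delta
        if err := json.Unmarshal([]byte(payload), &d); err != nil {
            continue
        }
        for _, c := range d.Choices {
            text.WriteString(c.Delta.Content)
        }
    }
    fmt.Println(text.String()) // Hello
}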
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -213,11 +227,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
    if err != nil {
        return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
    }
    resp.Body.Close()
    if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
        completionTokens := 0
        for _, choice := range simpleResponse.Choices {
@@ -232,3 +242,134 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
    }
    return nil, &simpleResponse.Usage
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    // Reset response body
    resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
    // We shouldn't set the header before we parse the response body, because the parse part may fail.
    // And then we will have to send an error response, but in this case, the header has already been set.
    // So the httpClient will be confused by the response.
    // For example, Postman will report error, and we cannot check the response at all.
    for k, v := range resp.Header {
        c.Writer.Header().Set(k, v[0])
    }
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = io.Copy(c.Writer, resp.Body)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    usage := &dto.Usage{}
    usage.PromptTokens = info.PromptTokens
    usage.TotalTokens = info.PromptTokens
    return nil, usage
}

func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
    var audioResp dto.AudioResponse
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    err = json.Unmarshal(responseBody, &audioResp)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    // Reset response body
    resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
    // We shouldn't set the header before we parse the response body, because the parse part may fail.
    // And then we will have to send an error response, but in this case, the header has already been set.
    // So the httpClient will be confused by the response.
    // For example, Postman will report error, and we cannot check the response at all.
    for k, v := range resp.Header {
        c.Writer.Header().Set(k, v[0])
    }
    c.Writer.WriteHeader(resp.StatusCode)
    _, err = io.Copy(c.Writer, resp.Body)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
    }
    resp.Body.Close()
    var text string
    switch responseFormat {
    case "json":
        text, err = getTextFromJSON(responseBody)
    case "text":
        text, err = getTextFromText(responseBody)
    case "srt":
        text, err = getTextFromSRT(responseBody)
    case "verbose_json":
        text, err = getTextFromVerboseJSON(responseBody)
    case "vtt":
        text, err = getTextFromVTT(responseBody)
    }
    usage := &dto.Usage{}
    usage.PromptTokens = info.PromptTokens
    usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
    usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
    return nil, usage
}

func getTextFromVTT(body []byte) (string, error) {
    return getTextFromSRT(body)
}

func getTextFromVerboseJSON(body []byte) (string, error) {
    var whisperResponse dto.WhisperVerboseJSONResponse
    if err := json.Unmarshal(body, &whisperResponse); err != nil {
        return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
    }
    return whisperResponse.Text, nil
}

func getTextFromSRT(body []byte) (string, error) {
    scanner := bufio.NewScanner(strings.NewReader(string(body)))
    var builder strings.Builder
    var textLine bool
    for scanner.Scan() {
        line := scanner.Text()
        if textLine {
            builder.WriteString(line)
            textLine = false
            continue
        } else if strings.Contains(line, "-->") {
            textLine = true
            continue
        }
    }
    if err := scanner.Err(); err != nil {
        return "", err
    }
    return builder.String(), nil
}

func getTextFromText(body []byte) (string, error) {
    return strings.TrimSuffix(string(body), "\n"), nil
}

func getTextFromJSON(body []byte) (string, error) {
    var whisperResponse dto.AudioResponse
    if err := json.Unmarshal(body, &whisperResponse); err != nil {
        return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
    }
    return whisperResponse.Text, nil
}
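The SRT/VTT helpers above rely on a simple property of those formats: the caption text is the line immediately following a timing line that contains "-->". A self-contained sketch of the same extraction loop, re-implemented locally with a made-up caption file, shows the behavior (note that, as in getTextFromSRT, consecutive captions are concatenated without a separator before token counting):

package main

import (
    "bufio"
    "fmt"
    "strings"
)

func main() {
    srt := "1\n00:00:00,000 --> 00:00:02,000\nhello there\n\n2\n00:00:02,000 --> 00:00:04,000\ngeneral kenobi\n"

    var builder strings.Builder
    textLine := false
    scanner := bufio.NewScanner(strings.NewReader(srt))
    for scanner.Scan() {
        line := scanner.Text()
        if textLine {
            // The line right after a timing line is caption text.
            builder.WriteString(line)
            textLine = false
            continue
        }
        if strings.Contains(line, "-->") {
            textLine = true
        }
    }
    fmt.Println(builder.String()) // hello theregeneral kenobi
}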

View File

@@ -15,12 +15,17 @@ import (
type Adaptor struct {
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

View File

@@ -10,18 +10,22 @@ import (
    "one-api/relay/channel"
    "one-api/relay/channel/openai"
    relaycommon "one-api/relay/common"
)

type Adaptor struct {
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -34,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
@@ -54,11 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
    if info.IsStream {
        err, usage = openai.OaiStreamHandler(c, resp, info)
    } else {
        err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
    }

View File

@@ -23,12 +23,17 @@ type Adaptor struct {
    Timestamp int64
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
    a.Action = "ChatCompletions"
    a.Version = "2023-09-01"
    a.Timestamp = common.GetTimestamp()
@@ -47,7 +52,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

View File

@@ -16,12 +16,17 @@ type Adaptor struct {
    request *dto.GeneralOpenAIRequest
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -33,7 +38,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

View File

@@ -14,12 +14,17 @@ import (
type Adaptor struct {
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -37,7 +42,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }

View File

@@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
    var usage *dto.Usage
    scanner := bufio.NewScanner(resp.Body)
    scanner.Split(bufio.ScanLines)
    dataChan := make(chan string)
    metaChan := make(chan string)
    stopChan := make(chan bool)

View File

@@ -10,18 +10,22 @@ import (
    "one-api/relay/channel"
    "one-api/relay/channel/openai"
    relaycommon "one-api/relay/common"
)

type Adaptor struct {
}

func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
    //TODO implement me
    return nil, errors.New("not implemented")
}

func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}

func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -35,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
@@ -55,13 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
    if info.IsStream {
        err, usage = openai.OaiStreamHandler(c, resp, info)
    } else {
        err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
    }

View File

@@ -1,7 +1,7 @@
package zhipu_4v

var ModelList = []string{
    "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
}

var ChannelName = "zhipu_4v"

View File

@@ -17,6 +17,7 @@ type RelayInfo struct {
    TokenUnlimited    bool
    StartTime         time.Time
    FirstResponseTime time.Time
    setFirstResponse  bool
    ApiType           int
    IsStream          bool
    RelayMode         int
@@ -32,7 +33,7 @@ type RelayInfo struct {
}

func GenRelayInfo(c *gin.Context) *RelayInfo {
    channelType := c.GetInt("channel_type")
    channelId := c.GetInt("channel_id")
    tokenId := c.GetInt("token_id")
@@ -83,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
    info.IsStream = isStream
}

func (info *RelayInfo) SetFirstResponseTime() {
    if !info.setFirstResponse {
        info.FirstResponseTime = time.Now()
        info.setFirstResponse = true
    }
}
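The setFirstResponse flag added above simply makes FirstResponseTime stick to the first streamed chunk, which presumably lets time-to-first-response be derived later as FirstResponseTime minus StartTime. A tiny illustration of the same guard pattern with a hypothetical type (not the real RelayInfo):

package main

import (
    "fmt"
    "time"
)

type timing struct {
    start         time.Time
    firstResponse time.Time
    set           bool
}

func (t *timing) markFirstResponse() {
    if !t.set {
        t.firstResponse = time.Now()
        t.set = true
    }
}

func main() {
    t := &timing{start: time.Now()}
    time.Sleep(20 * time.Millisecond)
    t.markFirstResponse() // recorded
    time.Sleep(20 * time.Millisecond)
    t.markFirstResponse() // ignored: only the first call counts
    fmt.Println("time to first response:", t.firstResponse.Sub(t.start))
}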
type TaskRelayInfo struct {
    ChannelType int
    ChannelId   int
@@ -104,7 +112,7 @@ type TaskRelayInfo struct {
}

func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
    channelType := c.GetInt("channel_type")
    channelId := c.GetInt("channel_id")
    tokenId := c.GetInt("token_id")

View File

@@ -1,50 +1,17 @@
package common

import (
    "fmt"
    "github.com/gin-gonic/gin"
    _ "image/gif"
    _ "image/jpeg"
    _ "image/png"
    "one-api/common"
    "strings"
)

var StopFinishReason = "stop"

func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
    fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

View File

@@ -13,6 +13,7 @@ const (
    RelayModeModerations
    RelayModeImagesGenerations
    RelayModeEdits
    RelayModeMidjourneyImagine
    RelayModeMidjourneyDescribe
    RelayModeMidjourneyBlend
@@ -22,16 +23,20 @@ const (
    RelayModeMidjourneyTaskFetch
    RelayModeMidjourneyTaskImageSeed
    RelayModeMidjourneyTaskFetchByCondition
    RelayModeMidjourneyAction
    RelayModeMidjourneyModal
    RelayModeMidjourneyShorten
    RelayModeSwapFace
    RelayModeMidjourneyUpload
    RelayModeAudioSpeech        // tts
    RelayModeAudioTranscription // whisper
    RelayModeAudioTranslation   // whisper
    RelayModeSunoFetch
    RelayModeSunoFetchByID
    RelayModeSunoSubmit
    RelayModeRerank
)
@@ -77,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
    } else if strings.HasSuffix(path, "/mj/insight-face/swap") {
        // midjourney plus
        relayMode = RelayModeSwapFace
    } else if strings.HasSuffix(path, "/submit/upload-discord-images") {
        // midjourney plus
        relayMode = RelayModeMidjourneyUpload
    } else if strings.HasSuffix(path, "/mj/submit/imagine") {
        relayMode = RelayModeMidjourneyImagine
    } else if strings.HasSuffix(path, "/mj/submit/blend") {

View File

@@ -1,13 +1,10 @@
package relay

import (
    "encoding/json"
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "net/http"
    "one-api/common"
    "one-api/constant"
@@ -16,69 +13,71 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings"
"time"
) )
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
tokenId := c.GetInt("token_id") audioRequest := &dto.AudioRequest{}
channelType := c.GetInt("channel") err := common.UnmarshalBodyReusable(c, audioRequest)
channelId := c.GetInt("channel_id") if err != nil {
userId := c.GetInt("id") return nil, err
group := c.GetString("group")
startTime := time.Now()
var audioRequest dto.TextToSpeechRequest
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} else {
audioRequest = dto.TextToSpeechRequest{
Model: "whisper-1",
}
} }
//err := common.UnmarshalBodyReusable(c, &audioRequest) switch info.RelayMode {
case relayconstant.RelayModeAudioSpeech:
// request validation if audioRequest.Model == "" {
if audioRequest.Model == "" { return nil, errors.New("model is required")
return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
} }
}
var err error
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if constant.ShouldCheckPromptSensitive() { if constant.ShouldCheckPromptSensitive() {
err = service.CheckSensitiveInput(audioRequest.Input) err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) return nil, err
} }
} }
default:
if audioRequest.Model == "" {
audioRequest.Model = c.PostForm("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
}
return audioRequest, nil
}
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
}
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
} }
preConsumedTokens = promptTokens preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
} }
modelRatio := common.GetModelRatio(audioRequest.Model) modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
} }
@@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
        preConsumedQuota = 0
    }
    if preConsumedQuota > 0 {
        userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
        if err != nil {
            return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
        }
    }

    // map model name
    modelMapping := c.GetString("model_mapping")
    if modelMapping != "" {
@@ -122,133 +105,44 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
            audioRequest.Model = modelMap[audioRequest.Model]
        }
    }
    relayInfo.UpstreamModelName = audioRequest.Model

    adaptor := GetAdaptor(relayInfo.ApiType)
    if adaptor == nil {
        return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
    }
    adaptor.Init(relayInfo)

    ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
    if err != nil {
        return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
    }
    resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
    if err != nil {
        return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
    }

    statusCodeMappingStr := c.GetString("status_code_mapping")
    if resp != nil {
        if resp.StatusCode != http.StatusOK {
            returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
            openaiErr := service.RelayErrorHandler(resp)
            // reset the status code according to the channel's status code mapping
            service.ResetStatusCode(openaiErr, statusCodeMappingStr)
            return openaiErr
        }
    }

    usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
    if openaiErr != nil {
        returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
        // reset the status code according to the channel's status code mapping
        service.ResetStatusCode(openaiErr, statusCodeMappingStr)
        return openaiErr
    }
    postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
    return nil
}
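The pre-charge logic above is plain arithmetic: counted tokens are scaled by the model and group ratios, deducted up front, and later settled against the real usage in postConsumeQuota. A worked example with made-up numbers (the ratios and token counts are illustrative, not the project's defaults):

package main

import "fmt"

func main() {
    // Illustrative values only.
    promptTokens := 1200 // e.g. counted from the TTS input text
    modelRatio := 7.5    // hypothetical ratio for the requested model
    groupRatio := 1.0    // hypothetical ratio for the user's group
    ratio := modelRatio * groupRatio

    preConsumedQuota := int(float64(promptTokens) * ratio)
    fmt.Println("pre-consumed quota:", preConsumedQuota) // 9000

    // After the upstream call, the real usage is known and the difference is settled.
    actualTokens := 900
    quota := int(float64(actualTokens) * ratio)
    fmt.Println("final quota:", quota, "delta:", quota-preConsumedQuota) // 6750 -2250
}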

View File

@@ -2,7 +2,6 @@ package relay
import (
    "bytes"
    "encoding/json"
    "errors"
    "fmt"
@@ -14,72 +13,71 @@ import (
    "one-api/dto"
    "one-api/model"
    relaycommon "one-api/relay/common"
    "one-api/service"
    "strings"
)

func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
    imageRequest := &dto.ImageRequest{}
    err := common.UnmarshalBodyReusable(c, imageRequest)
    if err != nil {
        return nil, err
    }
    if imageRequest.Prompt == "" {
        return nil, errors.New("prompt is required")
    }
    if strings.Contains(imageRequest.Size, "×") {
        return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
    }
    if imageRequest.N == 0 {
        imageRequest.N = 1
    }
    if imageRequest.Size == "" {
        imageRequest.Size = "1024x1024"
    }
    if imageRequest.Model == "" {
        imageRequest.Model = "dall-e-2"
    }
    if imageRequest.Quality == "" {
        imageRequest.Quality = "standard"
    }
    // Not "256x256", "512x512", or "1024x1024"
    if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
        if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
            return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
        }
    } else if imageRequest.Model == "dall-e-3" {
        if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
            return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
        }
        //if imageRequest.N != 1 {
        // return nil, errors.New("n must be 1")
        //}
    }
    // N should between 1 and 10
    //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
    // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
    //}
    if constant.ShouldCheckPromptSensitive() {
        err := service.CheckSensitiveInput(imageRequest.Prompt)
        if err != nil {
            return nil, err
        }
    }
    return imageRequest, nil
}

func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
    relayInfo := relaycommon.GenRelayInfo(c)

    imageRequest, err := getAndValidImageRequest(c, relayInfo)
    if err != nil {
        common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
        return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
    }

    // map model name
    modelMapping := c.GetString("model_mapping")
    if modelMapping != "" {
        modelMap := make(map[string]string)
        err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
        }
        if modelMap[imageRequest.Model] != "" {
            imageRequest.Model = modelMap[imageRequest.Model]
        }
    }
    relayInfo.UpstreamModelName = imageRequest.Model

    modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
    if !success {
@@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
        // per 1 modelRatio = $0.04 / 16
        modelPrice = 0.0025 * modelRatio
    }

    groupRatio := common.GetGroupRatio(relayInfo.Group)
    userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)

    sizeRatio := 1.0
    // Size
@@ -144,104 +121,67 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
} }
} }
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 { if userQuota-quota < 0 {
 		return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 	}
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+	var requestBody io.Reader
+	convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
-	token := c.Request.Header.Get("Authorization")
-	if channelType == common.ChannelTypeAzure { // Azure authentication
-		token = strings.TrimPrefix(token, "Bearer ")
-		req.Header.Set("api-key", token)
-	} else {
-		req.Header.Set("Authorization", token)
+	jsonData, err := json.Marshal(convertedRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
 	}
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	resp, err := service.GetHttpClient().Do(req)
+	requestBody = bytes.NewBuffer(jsonData)
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}
-	err = req.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	if resp.StatusCode != http.StatusOK {
-		return relaycommon.RelayErrorHandler(resp)
-	}
-	var textResponse dto.ImageResponse
-	defer func(ctx context.Context) {
-		useTimeSeconds := time.Now().Unix() - startTime.Unix()
+	if resp != nil {
+		relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 		if resp.StatusCode != http.StatusOK {
-			return
+			openaiErr := service.RelayErrorHandler(resp)
+			// reset status code 重置状态码
+			service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+			return openaiErr
 		}
-		err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
-		if err != nil {
-			common.SysError("error consuming token remain quota: " + err.Error())
-		}
-		err = model.CacheUpdateUserQuota(userId)
-		if err != nil {
-			common.SysError("error update user quota cache: " + err.Error())
-		}
-		if quota != 0 {
-			tokenName := c.GetString("token_name")
-			quality := "normal"
-			if imageRequest.Quality == "hd" {
-				quality = "hd"
-			}
-			logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality)
-			other := make(map[string]interface{})
-			other["model_price"] = modelPrice
-			other["group_ratio"] = groupRatio
-			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
-			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-			channelId := c.GetInt("channel_id")
-			model.UpdateChannelUsedQuota(channelId, quota)
-		}
-	}(c.Request.Context())
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
+	_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
+	if openaiErr != nil {
+		// reset status code 重置状态码
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+	usage := &dto.Usage{
+		PromptTokens: imageRequest.N,
+		TotalTokens:  imageRequest.N,
 	}
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
+	quality := "standard"
+	if imageRequest.Quality == "hd" {
+		quality = "hd"
+	}
+	logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+	postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
 	return nil
 }
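Taken together, the hunk above replaces the hand-rolled HTTP forwarding (manual header copy, body copy, deferred billing closure) with the same adaptor pipeline the text relay uses. A minimal sketch of that call order follows; imageAdaptor, RelayInfo and ImageRequest here are illustrative stand-ins for the repo's types, and error wrapping, quota accounting and status-code mapping are left out.

package relay

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
)

// Illustrative stand-ins for the repo's relay types.
type RelayInfo struct{ ApiType int }

type ImageRequest struct {
	Model   string
	Size    string
	Quality string
	N       int
}

// imageAdaptor is a hypothetical reduction of the adaptor interface used above:
// one object owns request conversion, the upstream call and response handling.
type imageAdaptor interface {
	Init(info *RelayInfo)
	ConvertImageRequest(c *gin.Context, info *RelayInfo, req ImageRequest) (interface{}, error)
	DoRequest(c *gin.Context, info *RelayInfo, body io.Reader) (*http.Response, error)
	DoResponse(c *gin.Context, resp *http.Response, info *RelayInfo) error
}

// relayImage shows only the call order the new code converges on.
func relayImage(c *gin.Context, a imageAdaptor, info *RelayInfo, req ImageRequest) error {
	a.Init(info)
	converted, err := a.ConvertImageRequest(c, info, req) // channel-specific payload
	if err != nil {
		return err
	}
	payload, err := json.Marshal(converted)
	if err != nil {
		return err
	}
	resp, err := a.DoRequest(c, info, bytes.NewBuffer(payload)) // upstream call
	if err != nil {
		return err
	}
	return a.DoResponse(c, resp, info) // write the upstream result back to the client
}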

View File

@@ -30,7 +30,7 @@ func RelayMidjourneyImage(c *gin.Context) {
 		})
 		return
 	}
-	resp, err := http.Get(midjourneyTask.ImageUrl)
+	resp, err := common.ProxiedHttpGet(midjourneyTask.ImageUrl, common.OutProxyUrl)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{
 			"error": "http_get_image_failed",
@@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
 	midjourneyTask.FinishTime = originTask.FinishTime
 	midjourneyTask.ImageUrl = ""
 	if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
-		midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
+		midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
 		if originTask.Status != "SUCCESS" {
 			midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
 		}
@@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		midjRequest.Action = constant.MjActionShorten
 	} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
 		midjRequest.Action = constant.MjActionBlend
+	} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
+		midjRequest.Action = constant.MjActionUpload
 	} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
 		mjId := ""
 		if relayMode == relayconstant.RelayModeMidjourneyChange {
@@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		if err != nil {
 			common.SysError("get_channel_null: " + err.Error())
 		}
-		if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
+		if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
 			model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
 		}
 	}
@@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 			responseBody = []byte(newBody)
 		}
 	}
+	if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
+		midjourneyTask.Progress = "100%"
+		midjourneyTask.Status = "SUCCESS"
+	}
 	err = midjourneyTask.Insert()
 	if err != nil {
 		return &dto.MidjourneyResponse{
@@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
 		newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
 		responseBody = []byte(newBody)
 	}
 	//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
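The first hunk swaps the plain http.Get for common.ProxiedHttpGet(..., common.OutProxyUrl), so Midjourney image fetches can be routed through an outbound proxy. The helper's implementation is not part of this diff; the sketch below is only a plausible shape for such a function, with the name proxiedHTTPGet chosen here for illustration.

package common

import (
	"net/http"
	"net/url"
)

// proxiedHTTPGet is a hypothetical sketch of a proxy-aware GET helper:
// with an empty proxyURL it behaves like http.Get, otherwise the request
// is sent through the given HTTP proxy.
func proxiedHTTPGet(target string, proxyURL string) (*http.Response, error) {
	client := http.DefaultClient
	if proxyURL != "" {
		parsed, err := url.Parse(proxyURL)
		if err != nil {
			return nil, err
		}
		client = &http.Client{
			Transport: &http.Transport{Proxy: http.ProxyURL(parsed)},
		}
	}
	return client.Get(target)
}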

View File

@@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
 	relayInfo.UpstreamModelName = textRequest.Model
-	modelPrice, success := common.GetModelPrice(textRequest.Model, false)
+	modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
 	groupRatio := common.GetGroupRatio(relayInfo.Group)
 	var preConsumedQuota int
@@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
 	}
-	if !success {
+	if !getModelPriceSuccess {
 		preConsumedTokens := common.PreConsumedQuota
 		if textRequest.MaxTokens != 0 {
 			preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
@@ -130,6 +130,12 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		return openaiErr
 	}
+	includeUsage := false
+	// 判断用户是否需要返回使用情况
+	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
+		includeUsage = true
+	}
 	// 如果不支持StreamOptions将StreamOptions设置为nil
 	if !relayInfo.SupportStreamOptions || !textRequest.Stream {
 		textRequest.StreamOptions = nil
@@ -142,18 +148,18 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		}
 	}
-	if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
-		relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
+	if includeUsage {
+		relayInfo.ShouldIncludeUsage = true
 	}
 	adaptor := GetAdaptor(relayInfo.ApiType)
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.Init(relayInfo, *textRequest)
+	adaptor.Init(relayInfo)
 	var requestBody io.Reader
-	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
+	convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
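The includeUsage lines added above exist because of ordering: the user's stream_options.include_usage flag has to be read before StreamOptions is cleared for channels that do not support it, otherwise ShouldIncludeUsage is lost for exactly those channels. A self-contained sketch of that ordering (helper name and parameters are illustrative, not the repo's API):

package relay

// StreamOptions mirrors the OpenAI stream_options object (illustrative subset).
type StreamOptions struct {
	IncludeUsage bool
}

// applyStreamOptions captures the caller's include_usage intent before
// StreamOptions is cleared for channels that do not support it — the
// ordering the hunks above introduce.
func applyStreamOptions(opts *StreamOptions, stream bool, channelSupportsStreamOptions bool) (cleaned *StreamOptions, includeUsage bool) {
	// Capture the flag first, while the original request is still intact.
	if opts != nil && opts.IncludeUsage {
		includeUsage = true
	}
	// Channels without stream_options support (or non-streaming calls) get nil.
	if !channelSupportsStreamOptions || !stream {
		opts = nil
	}
	return opts, includeUsage
}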
@@ -187,7 +193,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
 	return nil
 }
@@ -199,6 +205,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
 		promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
 	case relayconstant.RelayModeCompletions:
-		promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+		prompts := textRequest.Prompt
+		switch v := prompts.(type) {
+		case string:
+			prompts = v + textRequest.Suffix
+		case []string:
+			prompts = append(v, textRequest.Suffix)
+		}
+		promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
 	case relayconstant.RelayModeModerations:
 		promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
 	case relayconstant.RelayModeEmbeddings:
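For the completions relay mode, the new lines fold the request's suffix into the prompt before counting tokens, covering both the string and []string shapes the prompt field can take. A self-contained sketch of that type switch is below; countTokens is a hypothetical stand-in for the real tokenizer-based counter so the example runs on its own.

package relay

import "strings"

// countTokens stands in for the real tokenizer; here it counts
// whitespace-separated words only for the sake of a runnable example.
func countTokens(text string) int {
	return len(strings.Fields(text))
}

// completionPromptTokens mirrors the type switch added above: prompt may be
// a string or a []string, and the suffix is counted together with it.
func completionPromptTokens(prompt interface{}, suffix string) int {
	switch v := prompt.(type) {
	case string:
		return countTokens(v + suffix)
	case []string:
		total := 0
		for _, p := range append(v, suffix) {
			total += countTokens(p)
		}
		return total
	default:
		return 0
	}
}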
@@ -247,13 +262,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 		if tokenQuota > 100*preConsumedQuota {
 			// 令牌额度充足,信任令牌
 			preConsumedQuota = 0
-			common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+			common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
 		}
 	} else {
 		// in this case, we do not pre-consume quota
 		// because the user has enough quota
 		preConsumedQuota = 0
-		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+		common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
 	}
 	if preConsumedQuota > 0 {
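Only the common.LogInfo calls change in this hunk (they now take the gin context directly), but the surrounding lines show the pre-consume rule worth keeping in mind: unlimited tokens, and tokens whose remaining quota is more than 100x the estimated cost, skip pre-consumption entirely. A reduced sketch of that decision, with the 100x threshold from the code above (function name and parameters are illustrative):

package relay

// decidePreConsume condenses the logic visible above: unlimited tokens and
// tokens whose quota dwarfs the estimate are trusted and skip the pre-consume
// step; everyone else pre-consumes the estimated amount up front.
func decidePreConsume(estimated int, tokenQuota int, tokenUnlimited bool) int {
	if tokenUnlimited {
		return 0 // user has enough quota, trusted, no pre-consume
	}
	if tokenQuota > 100*estimated {
		return 0 // token quota is ample, trusted, no pre-consume
	}
	return estimated
}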
@@ -279,8 +294,15 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
 	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
-	modelPrice float64, usePrice bool) {
+	modelPrice float64, usePrice bool, extraContent string) {
+	if usage == nil {
+		usage = &dto.Usage{
+			PromptTokens:     relayInfo.PromptTokens,
+			CompletionTokens: 0,
+			TotalTokens:      relayInfo.PromptTokens,
+		}
+		extraContent += " ,(可能是请求出错)"
+	}
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens
 	completionTokens := usage.CompletionTokens
@@ -300,7 +322,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	}
 	totalTokens := promptTokens + completionTokens
 	var logContent string
-	if modelPrice == -1 {
+	if !usePrice {
 		logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
 	} else {
 		logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
@@ -337,6 +359,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
 	if strings.HasPrefix(logModel, "gpt-4-gizmo") {
 		logModel = "gpt-4-gizmo-*"
 		logContent += fmt.Sprintf(",模型 %s", modelName)
+	} else if strings.HasPrefix(logModel, "g-") {
+		logModel = "g-*"
+		logContent += fmt.Sprintf(",模型 %s", modelName)
+	}
+	if extraContent != "" {
+		logContent += ", " + extraContent
 	}
 	other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
 	model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,

View File

@@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	if adaptor == nil {
 		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
 	}
-	adaptor.InitRerank(relayInfo, *rerankRequest)
+	adaptor.Init(relayInfo)
 	convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
 	if err != nil {
@@ -99,6 +99,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
 		return openaiErr
 	}
-	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success)
+	postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
 	return nil
 }

View File

@@ -18,13 +18,14 @@ func SetApiRouter(router *gin.Engine) {
 	apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
 	apiRouter.GET("/notice", controller.GetNotice)
 	apiRouter.GET("/about", controller.GetAbout)
-	//apiRouter.GET("/midjourney", controller.GetMidjourney)
+	apiRouter.GET("/midjourney", controller.GetMidjourney)
 	apiRouter.GET("/home_page_content", controller.GetHomePageContent)
 	apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
 	apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
 	apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
 	apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
 	apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
+	apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
 	apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
 	apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
 	apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
@@ -32,13 +33,14 @@ func SetApiRouter(router *gin.Engine) {
 	apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
 	apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
+	apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
 	userRoute := apiRouter.Group("/user")
 	{
 		userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
 		userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
 		//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
 		userRoute.GET("/logout", controller.Logout)
-		userRoute.GET("/epay/notify", controller.EpayNotify)
 		selfRoute := userRoute.Group("/")
 		selfRoute.Use(middleware.UserAuth())
@@ -49,8 +51,8 @@ func SetApiRouter(router *gin.Engine) {
 		selfRoute.DELETE("/self", controller.DeleteSelf)
 		selfRoute.GET("/token", controller.GenerateAccessToken)
 		selfRoute.GET("/aff", controller.GetAffCode)
-		selfRoute.POST("/topup", controller.TopUp)
-		selfRoute.POST("/pay", controller.RequestEpay)
+		selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
+		selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
 		selfRoute.POST("/amount", controller.RequestAmount)
 		selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
 	}
@@ -64,7 +66,7 @@ func SetApiRouter(router *gin.Engine) {
 		adminRoute.POST("/", controller.CreateUser)
 		adminRoute.POST("/manage", controller.ManageUser)
 		adminRoute.PUT("/", controller.UpdateUser)
-		adminRoute.DELETE("/:id", controller.DeleteUser)
+		adminRoute.DELETE("/:id", controller.HardDeleteUser)
 	}
 }
 optionRoute := apiRouter.Group("/option")

View File

@@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
 	relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
 	relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
 	relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
+	relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
 	}
 }

View File

@@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
 	return false
 }
-func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
 	if !common.AutomaticEnableChannelEnabled {
 		return false
 	}
 	if err != nil {
 		return false
 	}
-	if openAIErr != nil {
+	if openaiWithStatusErr != nil {
 		return false
 	}
 	if status != common.ChannelStatusAutoDisabled {
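The only change here is the second parameter: callers now hand over the full wrapped error (status code included) rather than a bare OpenAIError. The decision itself is unchanged; the sketch below restates it with illustrative types and constants, not the repo's own.

package service

// OpenAIErrorWithStatusCode is an illustrative reduction of the dto type:
// the upstream error plus the HTTP status it arrived with.
type OpenAIErrorWithStatusCode struct {
	StatusCode int
	Message    string
}

const channelStatusAutoDisabled = 3 // illustrative value, not the repo's constant

// shouldEnableChannel mirrors the checks visible above: automatic re-enabling
// must be switched on, the probe must have produced no error at all, and the
// channel must currently be in the auto-disabled state.
func shouldEnableChannel(autoEnable bool, err error, upstreamErr *OpenAIErrorWithStatusCode, status int) bool {
	if !autoEnable {
		return false
	}
	if err != nil || upstreamErr != nil {
		return false
	}
	return status == channelStatusAutoDisabled
}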

View File

@@ -1,12 +0,0 @@
package service
import (
"one-api/constant"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
return constant.ServerAddress
}
return constant.CustomCallbackAddress
}

View File

@@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
 	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 		StatusCode: resp.StatusCode,
 		Error: dto.OpenAIError{
-			Message: "",
 			Type:    "upstream_error",
 			Code:    "bad_response_status_code",
 			Param:   strconv.Itoa(resp.StatusCode),
 		},
 	}
 	responseBody, err := io.ReadAll(resp.Body)

Some files were not shown because too many files have changed in this diff.