mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
234 Commits
v0.1.3.1
...
v0.2.1.0-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34bf8f8945 | ||
|
|
2d1d1b4631 | ||
|
|
fbdb17022c | ||
|
|
497cc32634 | ||
|
|
462c328d4b | ||
|
|
257cfc2390 | ||
|
|
fed1a1d6a3 | ||
|
|
fc9f8c8e8a | ||
|
|
f3f36dafbd | ||
|
|
aaf3a1f07b | ||
|
|
c040fa229d | ||
|
|
1cd1e54be4 | ||
|
|
3db64afc7f | ||
|
|
bc9cfa5da0 | ||
|
|
660b9b3c99 | ||
|
|
cdf2087952 | ||
|
|
4b60528c5f | ||
|
|
9025756b56 | ||
|
|
2ea6009954 | ||
|
|
a33f685f3c | ||
|
|
3d0f77ffb6 | ||
|
|
5ce8e6dab6 | ||
|
|
5a5b7d618d | ||
|
|
ad8ce915ec | ||
|
|
456fb875de | ||
|
|
3e90b6d516 | ||
|
|
d6e373fbe4 | ||
|
|
224746b45a | ||
|
|
ac827b1862 | ||
|
|
658bf2ad57 | ||
|
|
c25f48b7c5 | ||
|
|
290dcf7587 | ||
|
|
278fd39195 | ||
|
|
aa23c51a53 | ||
|
|
87919b032d | ||
|
|
f7a4f18aff | ||
|
|
706449dede | ||
|
|
36d164be0e | ||
|
|
d80a7d3c97 | ||
|
|
44a8ade4ba | ||
|
|
2cca2a989e | ||
|
|
3065bf92ae | ||
|
|
2e595bdafb | ||
|
|
49df4b6eed | ||
|
|
5c39f54040 | ||
|
|
786ccc7da0 | ||
|
|
8eedad9470 | ||
|
|
319e97d677 | ||
|
|
6114c9bb96 | ||
|
|
3cf2f0d5cb | ||
|
|
2a345ae070 | ||
|
|
d8c91fa448 | ||
|
|
cc8cc8b386 | ||
|
|
1587ea565b | ||
|
|
a7a1fc615d | ||
|
|
b2a280c1ec | ||
|
|
f1fb7b32a3 | ||
|
|
3800dc219e | ||
|
|
72962e988f | ||
|
|
01e3acfada | ||
|
|
f671176da0 | ||
|
|
2d36dee17c | ||
|
|
6eb30ec3e6 | ||
|
|
0b3520e3c8 | ||
|
|
63304a5b2d | ||
|
|
66e30f4115 | ||
|
|
0618f03c68 | ||
|
|
962dc984f4 | ||
|
|
15e7307320 | ||
|
|
951383c371 | ||
|
|
87b6210045 | ||
|
|
525fc1b3b7 | ||
|
|
58f2cf3a79 | ||
|
|
06c86397e1 | ||
|
|
21f48b55e0 | ||
|
|
f823b4d4d8 | ||
|
|
93be61aaf3 | ||
|
|
a500097b36 | ||
|
|
67332bc8df | ||
|
|
d0acecb2ab | ||
|
|
a825699e9a | ||
|
|
a70ca53449 | ||
|
|
c33b1522cc | ||
|
|
ff7da08bad | ||
|
|
3e03c5a742 | ||
|
|
d9344d79cf | ||
|
|
c4b3d3a975 | ||
|
|
031957714a | ||
|
|
3f808be254 | ||
|
|
9b64f4a34a | ||
|
|
222a55387d | ||
|
|
492001a8b2 | ||
|
|
d7e25e1604 | ||
|
|
7d64f30f4d | ||
|
|
9e157ed802 | ||
|
|
cfabf8a656 | ||
|
|
c7658b70d1 | ||
|
|
d5e93e788d | ||
|
|
dd71946047 | ||
|
|
b736de7189 | ||
|
|
5e7774cf08 | ||
|
|
a232afe9fd | ||
|
|
eb6257a8d8 | ||
|
|
47b9f48544 | ||
|
|
2db4282666 | ||
|
|
64b9d3b58c | ||
|
|
7a663d26ec | ||
|
|
bec21ade9d | ||
|
|
4c4e087060 | ||
|
|
beadb98a8c | ||
|
|
615d109d70 | ||
|
|
0da49fa446 | ||
|
|
578b5f6536 | ||
|
|
f63ad9c03c | ||
|
|
0fb98e44a7 | ||
|
|
652bb4a53c | ||
|
|
a26b9a9bff | ||
|
|
f82aa956bd | ||
|
|
b8c053c37f | ||
|
|
2581b37394 | ||
|
|
f65477d054 | ||
|
|
4b952b8582 | ||
|
|
a6ba1d01d9 | ||
|
|
2d2fec24d0 | ||
|
|
d34b55c154 | ||
|
|
25ec99913b | ||
|
|
7f22d58574 | ||
|
|
dfdeadf1a5 | ||
|
|
6ab1b3a524 | ||
|
|
4917e5a92f | ||
|
|
f62dcbf669 | ||
|
|
299911d4cd | ||
|
|
84e0544604 | ||
|
|
2786a6b539 | ||
|
|
9b5353a81a | ||
|
|
bc5a54df59 | ||
|
|
d704902b70 | ||
|
|
614220a0fb | ||
|
|
98c1f66d61 | ||
|
|
a77fbc0fa2 | ||
|
|
44361d75e8 | ||
|
|
9b2e5c2978 | ||
|
|
d3399d68f6 | ||
|
|
3d10c9f090 | ||
|
|
d5ffaf2502 | ||
|
|
2ad591411e | ||
|
|
728dbed28d | ||
|
|
fd3a41bacb | ||
|
|
37c0c8ebdd | ||
|
|
95d8059c90 | ||
|
|
c012306400 | ||
|
|
1e1a53e4d2 | ||
|
|
30d48ea473 | ||
|
|
9e5c636490 | ||
|
|
d53d3386e9 | ||
|
|
c3a01decd8 | ||
|
|
c29680b301 | ||
|
|
ac7407ce9c | ||
|
|
3ea061a820 | ||
|
|
d4df1960b2 | ||
|
|
97dd80541b | ||
|
|
bf241b218f | ||
|
|
7ab6c6c303 | ||
|
|
8fe8340b6e | ||
|
|
26ef906c61 | ||
|
|
655dfe0d09 | ||
|
|
f43b268520 | ||
|
|
37113c0e96 | ||
|
|
3c3c53051d | ||
|
|
6a83c8ad86 | ||
|
|
0640dd81fd | ||
|
|
cb50fcaffe | ||
|
|
eca48268b2 | ||
|
|
4a0af1ea3c | ||
|
|
c2965eb835 | ||
|
|
4186880e4c | ||
|
|
280c63e1d4 | ||
|
|
9e0a54e943 | ||
|
|
0e06be8c3e | ||
|
|
931d22c96f | ||
|
|
6413bf342a | ||
|
|
626217fbd4 | ||
|
|
9de2d21e1a | ||
|
|
3ab4f145db | ||
|
|
da73dca9a7 | ||
|
|
415d296171 | ||
|
|
d5c5c30312 | ||
|
|
fb39b6b30e | ||
|
|
92c1ed7f1d | ||
|
|
d160736a49 | ||
|
|
fe7f42fc2e | ||
|
|
5cb933a278 | ||
|
|
912f46fcd2 | ||
|
|
1ad1112351 | ||
|
|
1f0a48a879 | ||
|
|
d2de675564 | ||
|
|
a5cfeeaa63 | ||
|
|
37b307a784 | ||
|
|
02d5a5f16d | ||
|
|
699fe256d0 | ||
|
|
54088bc664 | ||
|
|
06b746a740 | ||
|
|
4c43012e6c | ||
|
|
a16e949318 | ||
|
|
72194bda98 | ||
|
|
ca5c4d2dbd | ||
|
|
cbe0a1fa76 | ||
|
|
d6d8ad2b7e | ||
|
|
7297c6269f | ||
|
|
cb92d6fd5f | ||
|
|
e5cea80103 | ||
|
|
84144306a8 | ||
|
|
ac64fd26ad | ||
|
|
a8f0c5dab2 | ||
|
|
2b076eaed2 | ||
|
|
194ff1ac57 | ||
|
|
84cac72a45 | ||
|
|
413d4f0a66 | ||
|
|
690d2e6ba2 | ||
|
|
140e843d93 | ||
|
|
feb40db2bc | ||
|
|
b4645d1019 | ||
|
|
a3687b72f8 | ||
|
|
6013219f5b | ||
|
|
5b18cd6b0a | ||
|
|
9b421478c1 | ||
|
|
569751f5db | ||
|
|
608c945ae6 | ||
|
|
05beade3ad | ||
|
|
81167c4322 | ||
|
|
6fa1837c24 | ||
|
|
ba0c5bb4d9 | ||
|
|
5482fff62d | ||
|
|
1cde58c089 |
1
.github/workflows/docker-image-amd64.yml
vendored
1
.github/workflows/docker-image-amd64.yml
vendored
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
||||
1
.github/workflows/docker-image-arm64.yml
vendored
1
.github/workflows/docker-image-arm64.yml
vendored
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
|
||||
@@ -5,7 +5,7 @@ COPY web/package.json .
|
||||
RUN npm install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
|
||||
FROM golang AS builder2
|
||||
|
||||
@@ -17,7 +17,7 @@ WORKDIR /build
|
||||
ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
COPY --from=builder /build/build ./web/build
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
||||
|
||||
FROM alpine
|
||||
|
||||
314
Midjourney.md
314
Midjourney.md
@@ -2,290 +2,66 @@
|
||||
|
||||
**简介**:Midjourney Proxy API文档
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
## 模型列表
|
||||
|
||||
### midjourney-proxy支持
|
||||
|
||||
- mj_imagine (绘图)
|
||||
- mj_variation (变换)
|
||||
- mj_reroll (重绘)
|
||||
- mj_blend (混合)
|
||||
- mj_upscale (放大)
|
||||
- mj_describe (图生文)
|
||||
|
||||
### 仅midjourney-proxy-plus支持
|
||||
|
||||
- mj_zoom (比例变焦)
|
||||
- mj_shorten (提示词缩短)
|
||||
- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
|
||||
- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
|
||||
- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
|
||||
- mj_high_variation (强变换)
|
||||
- mj_low_variation (弱变换)
|
||||
- mj_pan (平移)
|
||||
- swap_face (换脸)
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
```json
|
||||
{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_modal": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05
|
||||
}
|
||||
```
|
||||
其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
|
||||
|
||||
## 渠道设置
|
||||
|
||||
### 对接 midjourney-proxy
|
||||
1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||
2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
||||
### 对接 midjourney-proxy(plus)
|
||||
|
||||
1.
|
||||
|
||||
部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||
|
||||
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
||||
,模型请参考上方模型列表
|
||||
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
||||
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
||||
|
||||
### 对接上游new api
|
||||
1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
||||
2. 地址填写上游new api的地址,例如:http://localhost:8080
|
||||
3. 密钥填写上游new api的密钥
|
||||
|
||||
## 任务提交
|
||||
|
||||
### 绘图变化
|
||||
|
||||
**接口地址**:`/mj/submit/change`
|
||||
|
||||
**请求方式**:`POST`
|
||||
|
||||
**请求数据类型**:`application/json`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"action"
|
||||
:
|
||||
"UPSCALE",
|
||||
"index"
|
||||
:
|
||||
1,
|
||||
"notifyHook"
|
||||
:
|
||||
"",
|
||||
"state"
|
||||
:
|
||||
"",
|
||||
"taskId"
|
||||
:
|
||||
"1320098173412546"
|
||||
}
|
||||
```
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
|
||||
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
|
||||
|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
|
||||
|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
|
||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
||||
|   state | 自定义参数 | | false | string | |
|
||||
|   taskId | 任务ID | | true | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 提交结果 |
|
||||
| 201 | Created | |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|-------------------------------------------|----------------|----------------|
|
||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
||||
| description | 描述 | string | |
|
||||
| properties | 扩展字段 | object | |
|
||||
| result | 任务ID | string | |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"code"
|
||||
:
|
||||
1,
|
||||
"description"
|
||||
:
|
||||
"提交成功",
|
||||
"properties"
|
||||
:
|
||||
{
|
||||
}
|
||||
,
|
||||
"result"
|
||||
:
|
||||
1320098173412546
|
||||
}
|
||||
```
|
||||
|
||||
### 提交Imagine任务
|
||||
|
||||
**接口地址**:`/mj/submit/imagine`
|
||||
|
||||
**请求方式**:`POST`
|
||||
|
||||
**请求数据类型**:`application/json`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"base64"
|
||||
:
|
||||
"",
|
||||
"notifyHook"
|
||||
:
|
||||
"",
|
||||
"prompt"
|
||||
:
|
||||
"Cat",
|
||||
"state"
|
||||
:
|
||||
""
|
||||
}
|
||||
```
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------------------------|-------------------------|------|-------|-------------|-------------|
|
||||
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
|
||||
|   base64 | 垫图base64 | | false | string | |
|
||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
||||
|   prompt | 提示词 | | true | string | |
|
||||
|   state | 自定义参数 | | false | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 提交结果 |
|
||||
| 201 | Created | |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|-------------------------------------------|----------------|----------------|
|
||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
||||
| description | 描述 | string | |
|
||||
| properties | 扩展字段 | object | |
|
||||
| result | 任务ID | string | |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"code"
|
||||
:
|
||||
1,
|
||||
"description"
|
||||
:
|
||||
"提交成功",
|
||||
"properties"
|
||||
:
|
||||
{
|
||||
}
|
||||
,
|
||||
"result"
|
||||
:
|
||||
1320098173412546
|
||||
}
|
||||
```
|
||||
|
||||
## 任务查询
|
||||
|
||||
### 指定ID获取任务
|
||||
|
||||
**接口地址**:`/mj/task/{id}/fetch`
|
||||
|
||||
**请求方式**:`GET`
|
||||
|
||||
**请求数据类型**:`application/x-www-form-urlencoded`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------|------|------|-------|--------|--------|
|
||||
| id | 任务ID | path | false | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 任务 |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|----------------------------------------------------------|----------------|----------------|
|
||||
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
|
||||
| description | 任务描述 | string | |
|
||||
| failReason | 失败原因 | string | |
|
||||
| finishTime | 结束时间 | integer(int64) | integer(int64) |
|
||||
| id | 任务ID | string | |
|
||||
| imageUrl | 图片url | string | |
|
||||
| progress | 任务进度 | string | |
|
||||
| prompt | 提示词 | string | |
|
||||
| promptEn | 提示词-英文 | string | |
|
||||
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
|
||||
| state | 自定义参数 | string | |
|
||||
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
|
||||
| submitTime | 提交时间 | integer(int64) | integer(int64) |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"action"
|
||||
:
|
||||
"",
|
||||
"description"
|
||||
:
|
||||
"",
|
||||
"failReason"
|
||||
:
|
||||
"",
|
||||
"finishTime"
|
||||
:
|
||||
0,
|
||||
"id"
|
||||
:
|
||||
"",
|
||||
"imageUrl"
|
||||
:
|
||||
"",
|
||||
"progress"
|
||||
:
|
||||
"",
|
||||
"prompt"
|
||||
:
|
||||
"",
|
||||
"promptEn"
|
||||
:
|
||||
"",
|
||||
"startTime"
|
||||
:
|
||||
0,
|
||||
"state"
|
||||
:
|
||||
"",
|
||||
"status"
|
||||
:
|
||||
"",
|
||||
"submitTime"
|
||||
:
|
||||
0
|
||||
}
|
||||
```
|
||||
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
|
||||
2. 地址填写上游new api的地址,例如:http://localhost:3000
|
||||
3. 密钥填写上游new api的密钥
|
||||
74
README.md
74
README.md
@@ -14,30 +14,60 @@
|
||||
> 最新版Docker镜像 calciumion/new-api:latest
|
||||
> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
|
||||
## 此分叉版本的主要变更
|
||||
## 主要变更
|
||||
此分叉版本的主要变更如下:
|
||||
|
||||
1. 全新的UI界面(部分界面还待更新)
|
||||
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
+ [x] /mj/submit/describe
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
+ [x] /mj/submit/describe
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||
+ [x] /mj/submit/modal
|
||||
+ [x] /mj/submit/shorten
|
||||
+ [x] /mj/task/{id}/image-seed
|
||||
+ [x] /mj/insight-face/swap (InsightFace)
|
||||
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
||||
+ [x] 易支付
|
||||
+ [x] 易支付
|
||||
4. 支持用key查询使用额度:
|
||||
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销
|
||||
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
|
||||
5. 渠道显示已使用额度,支持指定组织访问
|
||||
6. 分页支持选择每页显示数量
|
||||
7. 支持 gpt-4-1106-vision-preview,dall-e-3,tts-1
|
||||
8. 支持第三方模型 **gps** (gpt-4-gizmo-*),在渠道中添加自定义模型gpt-4-gizmo-*即可
|
||||
9. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
11. 支持gemini-pro,gemini-pro-vision模型
|
||||
12. 支持渠道**加权随机**
|
||||
13. 数据看板
|
||||
14. 可设置令牌能调用的模型
|
||||
7. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
8. 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
9. 支持渠道**加权随机**
|
||||
10. 数据看板
|
||||
11. 可设置令牌能调用的模型
|
||||
12. 支持Telegram授权登录。
|
||||
1. 系统设置-配置登录注册-允许通过Telegram登录
|
||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||
2. 智谱glm-4v,glm-4v识图
|
||||
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||
6. [零一万物](https://platform.lingyiwanwu.com/)
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
## 渠道重试
|
||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
|
||||
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||
### 缓存设置方法
|
||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||
|
||||
|
||||
## 部署
|
||||
### 基于 Docker 进行部署
|
||||
@@ -79,6 +109,12 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||

|
||||

|
||||
|
||||
## 相关项目
|
||||
- [One API](https://github.com/songquanpeng/one-api):原版项目
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
|
||||
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -9,14 +9,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Pay Settings
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "New API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var PayAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
@@ -29,6 +34,7 @@ var DrawingEnabled = true
|
||||
var DataExportEnabled = true
|
||||
var DataExportInterval = 5 // unit: minute
|
||||
var DataExportDefaultTime = "hour" // unit: minute
|
||||
var DefaultCollapseSidebar = false // default value of collapse sidebar
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
@@ -45,10 +51,12 @@ var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
|
||||
var EmailDomainRestrictionEnabled = false
|
||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||
var EmailDomainWhitelist = []string{
|
||||
"gmail.com",
|
||||
"163.com",
|
||||
@@ -68,6 +76,7 @@ var LogConsumeEnabled = true
|
||||
|
||||
var SMTPServer = ""
|
||||
var SMTPPort = 587
|
||||
var SMTPSSLEnabled = false
|
||||
var SMTPAccount = ""
|
||||
var SMTPFrom = ""
|
||||
var SMTPToken = ""
|
||||
@@ -82,6 +91,9 @@ var WeChatAccountQRCodeImageURL = ""
|
||||
var TurnstileSiteKey = ""
|
||||
var TurnstileSecretKey = ""
|
||||
|
||||
var TelegramBotToken = ""
|
||||
var TelegramBotName = ""
|
||||
|
||||
var QuotaForNewUser = 0
|
||||
var QuotaForInviter = 0
|
||||
var QuotaForInvitee = 0
|
||||
@@ -100,7 +112,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
@@ -176,10 +188,10 @@ const (
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeAPI2D = 2
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeCloseAI = 4
|
||||
ChannelTypeOpenAISB = 5
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
@@ -199,6 +211,10 @@ const (
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@@ -206,7 +222,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
@@ -226,5 +242,12 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", //24
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
}
|
||||
|
||||
@@ -2,5 +2,6 @@ package common
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var UsingMySQL = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
||||
|
||||
@@ -24,7 +24,7 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
var err error
|
||||
if SMTPPort == 465 {
|
||||
if SMTPPort == 465 || SMTPSSLEnabled {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: SMTPServer,
|
||||
|
||||
@@ -5,18 +5,37 @@ import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
requestBody, err := GetRequestBody(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/chai2010/webp"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, "", err
|
||||
return image.Config{}, "", "", err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, format, err := getImageConfig(reader)
|
||||
return config, format, err
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
|
||||
@@ -3,17 +3,16 @@ package common
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelRatio
|
||||
// modelRatio
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||
// https://openai.com/pricing
|
||||
// TODO: when a new api is enabled, check the pricing here
|
||||
// 1 === $0.002 / 1K tokens
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
var ModelRatio = map[string]float64{
|
||||
var DefaultModelRatio = map[string]float64{
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4": 15,
|
||||
@@ -27,7 +26,7 @@ var ModelRatio = map[string]float64{
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
@@ -61,42 +60,75 @@ var ModelRatio = map[string]float64{
|
||||
"text-moderation-latest": 0.1,
|
||||
"dall-e-2": 8,
|
||||
"dall-e-3": 16,
|
||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||
"claude-instant-1": 0.4, // $0.8 / 1M tokens
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-1.0-pro-vision-001": 1,
|
||||
"gemini-1.0-pro-001": 1,
|
||||
"gemini-1.5-pro": 1,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"glm-4": 7.143, // ¥0.1 / 1k tokens
|
||||
"glm-4v": 7.143, // ¥0.1 / 1k tokens
|
||||
"glm-3-turbo": 0.3572,
|
||||
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
||||
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||
// 已经按照 7.2 来换算美元价格
|
||||
"yi-34b-chat-0205": 0.018,
|
||||
"yi-34b-chat-200k": 0.0864,
|
||||
"yi-vl-plus": 0.0432,
|
||||
}
|
||||
|
||||
var ModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_modal": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05,
|
||||
}
|
||||
|
||||
var modelPrice map[string]float64 = nil
|
||||
var modelRatio map[string]float64 = nil
|
||||
|
||||
func ModelPrice2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelPrice)
|
||||
if modelPrice == nil {
|
||||
modelPrice = DefaultModelPrice
|
||||
}
|
||||
jsonBytes, err := json.Marshal(modelPrice)
|
||||
if err != nil {
|
||||
SysError("error marshalling model price: " + err.Error())
|
||||
}
|
||||
@@ -104,15 +136,18 @@ func ModelPrice2JSONString() string {
|
||||
}
|
||||
|
||||
func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
ModelPrice = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
|
||||
modelPrice = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &modelPrice)
|
||||
}
|
||||
|
||||
func GetModelPrice(name string, printErr bool) float64 {
|
||||
if modelPrice == nil {
|
||||
modelPrice = DefaultModelPrice
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
price, ok := ModelPrice[name]
|
||||
price, ok := modelPrice[name]
|
||||
if !ok {
|
||||
if printErr {
|
||||
SysError("model price not found: " + name)
|
||||
@@ -123,7 +158,10 @@ func GetModelPrice(name string, printErr bool) float64 {
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelRatio)
|
||||
if modelRatio == nil {
|
||||
modelRatio = DefaultModelRatio
|
||||
}
|
||||
jsonBytes, err := json.Marshal(modelRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
@@ -131,15 +169,18 @@ func ModelRatio2JSONString() string {
|
||||
}
|
||||
|
||||
func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
ModelRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
|
||||
modelRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &modelRatio)
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) float64 {
|
||||
if modelRatio == nil {
|
||||
modelRatio = DefaultModelRatio
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
ratio, ok := ModelRatio[name]
|
||||
ratio, ok := modelRatio[name]
|
||||
if !ok {
|
||||
SysError("model ratio not found: " + name)
|
||||
return 30
|
||||
@@ -149,22 +190,15 @@ func GetModelRatio(name string) float64 {
|
||||
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
if strings.HasSuffix(name, "0125") {
|
||||
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
|
||||
// https://openai.com/blog/new-embedding-models-and-api-updates
|
||||
// Updated GPT-3.5 Turbo model and lower pricing
|
||||
return 3
|
||||
}
|
||||
if strings.HasSuffix(name, "1106") {
|
||||
return 2
|
||||
}
|
||||
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
|
||||
// TODO: clear this after 2023-12-11
|
||||
now := time.Now()
|
||||
// https://platform.openai.com/docs/models/continuous-model-upgrades
|
||||
// if after 2023-12-11, use 2
|
||||
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
|
||||
return 2
|
||||
}
|
||||
}
|
||||
return 1.333333
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") {
|
||||
if strings.HasSuffix(name, "preview") {
|
||||
@@ -173,10 +207,21 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
return 3.38
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-2") {
|
||||
return 2.965517
|
||||
if strings.HasPrefix(name, "mistral-") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-") {
|
||||
return 3
|
||||
}
|
||||
switch name {
|
||||
case "llama2-70b-4096":
|
||||
return 0.8 / 0.7
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ func InitRedisClient() (err error) {
|
||||
return nil
|
||||
}
|
||||
if os.Getenv("SYNC_FREQUENCY") == "" {
|
||||
RedisEnabled = false
|
||||
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
|
||||
return nil
|
||||
SysLog("SYNC_FREQUENCY not set, use default value 60")
|
||||
SyncFrequency = 60
|
||||
}
|
||||
SysLog("Redis is enabled")
|
||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||
@@ -73,6 +72,35 @@ func RedisDel(key string) error {
|
||||
}
|
||||
|
||||
func RedisDecrease(key string, value int64) error {
|
||||
ctx := context.Background()
|
||||
return RDB.DecrBy(ctx, key, value).Err()
|
||||
|
||||
// 检查键的剩余生存时间
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil {
|
||||
// 失败则尝试直接减少
|
||||
return RDB.DecrBy(context.Background(), key, value).Err()
|
||||
}
|
||||
|
||||
// 如果剩余生存时间大于0,则进行减少操作
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
// 开始一个Redis事务
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
// 减少余额
|
||||
decrCmd := txn.DecrBy(ctx, key, value)
|
||||
if err := decrCmd.Err(); err != nil {
|
||||
return err // 如果减少失败,则直接返回错误
|
||||
}
|
||||
|
||||
// 重新设置过期时间,使用原来的过期时间
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
// 执行事务
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
} else {
|
||||
_ = RedisDel(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
50
common/str.go
Normal file
50
common/str.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package common
|
||||
|
||||
func SundaySearch(text string, pattern string) bool {
|
||||
// 计算偏移表
|
||||
offset := make(map[rune]int)
|
||||
for i, c := range pattern {
|
||||
offset[c] = len(pattern) - i
|
||||
}
|
||||
|
||||
// 文本串长度和模式串长度
|
||||
n, m := len(text), len(pattern)
|
||||
|
||||
// 主循环,i表示当前对齐的文本串位置
|
||||
for i := 0; i <= n-m; {
|
||||
// 检查子串
|
||||
j := 0
|
||||
for j < m && text[i+j] == pattern[j] {
|
||||
j++
|
||||
}
|
||||
// 如果完全匹配,返回匹配位置
|
||||
if j == m {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果还有剩余字符,则检查下一位字符在偏移表中的值
|
||||
if i+m < n {
|
||||
next := rune(text[i+m])
|
||||
if val, ok := offset[next]; ok {
|
||||
i += val // 存在于偏移表中,进行跳跃
|
||||
} else {
|
||||
i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return false // 如果没有找到匹配,返回-1
|
||||
}
|
||||
|
||||
func RemoveDuplicate(s []string) []string {
|
||||
result := make([]string, 0, len(s))
|
||||
temp := map[string]struct{}{}
|
||||
for _, item := range s {
|
||||
if _, ok := temp[item]; !ok {
|
||||
temp[item] = struct{}{}
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -230,9 +230,14 @@ func StringsContains(strs []string, str string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// []byte only read, panic on append
|
||||
// StringToByteSlice []byte only read, panic on append
|
||||
func StringToByteSlice(s string) []byte {
|
||||
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
44
constant/midjourney.go
Normal file
44
constant/midjourney.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MjActionImagine = "IMAGINE"
|
||||
MjActionDescribe = "DESCRIBE"
|
||||
MjActionBlend = "BLEND"
|
||||
MjActionUpscale = "UPSCALE"
|
||||
MjActionVariation = "VARIATION"
|
||||
MjActionReRoll = "REROLL"
|
||||
MjActionInPaint = "INPAINT"
|
||||
MjActionModal = "MODAL"
|
||||
MjActionZoom = "ZOOM"
|
||||
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||
MjActionShorten = "SHORTEN"
|
||||
MjActionHighVariation = "HIGH_VARIATION"
|
||||
MjActionLowVariation = "LOW_VARIATION"
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
"mj_imagine": MjActionImagine,
|
||||
"mj_describe": MjActionDescribe,
|
||||
"mj_blend": MjActionBlend,
|
||||
"mj_upscale": MjActionUpscale,
|
||||
"mj_variation": MjActionVariation,
|
||||
"mj_reroll": MjActionReRoll,
|
||||
"mj_modal": MjActionModal,
|
||||
"mj_inpaint": MjActionInPaint,
|
||||
"mj_zoom": MjActionZoom,
|
||||
"mj_custom_zoom": MjActionCustomZoom,
|
||||
"mj_shorten": MjActionShorten,
|
||||
"mj_high_variation": MjActionHighVariation,
|
||||
"mj_low_variation": MjActionLowVariation,
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
}
|
||||
43
constant/sensitive.go
Normal file
43
constant/sensitive.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package constant
|
||||
|
||||
import "strings"
|
||||
|
||||
var CheckSensitiveEnabled = true
|
||||
var CheckSensitiveOnPromptEnabled = true
|
||||
|
||||
//var CheckSensitiveOnCompletionEnabled = true
|
||||
|
||||
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
|
||||
var StopOnSensitiveEnabled = true
|
||||
|
||||
// StreamCacheQueueLength 流模式缓存队列长度,0表示无缓存
|
||||
var StreamCacheQueueLength = 0
|
||||
|
||||
// SensitiveWords 敏感词
|
||||
// var SensitiveWords []string
|
||||
var SensitiveWords = []string{
|
||||
"test",
|
||||
}
|
||||
|
||||
func SensitiveWordsToString() string {
|
||||
return strings.Join(SensitiveWords, "\n")
|
||||
}
|
||||
|
||||
func SensitiveWordsFromString(s string) {
|
||||
SensitiveWords = []string{}
|
||||
sw := strings.Split(s, "\n")
|
||||
for _, w := range sw {
|
||||
w = strings.TrimSpace(w)
|
||||
if w != "" {
|
||||
SensitiveWords = append(SensitiveWords, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ShouldCheckPromptSensitive() bool {
|
||||
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
|
||||
}
|
||||
|
||||
//func ShouldCheckCompletionSensitive() bool {
|
||||
// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
|
||||
//}
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
@@ -27,7 +28,7 @@ func GetSubscription(c *gin.Context) {
|
||||
expiredTime = 0
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := OpenAIError{
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "upstream_error",
|
||||
}
|
||||
@@ -69,7 +70,7 @@ func GetUsage(c *gin.Context) {
|
||||
quota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := OpenAIError{
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "new_api_error",
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := httpClient.Do(req)
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -213,10 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
case common.ChannelTypeCloseAI:
|
||||
return updateChannelCloseAIBalance(channel)
|
||||
case common.ChannelTypeOpenAISB:
|
||||
return updateChannelOpenAISBBalance(channel)
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
@@ -310,7 +309,7 @@ func updateAllChannelsBalance() error {
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
disableChannel(channel.Id, channel.Name, "余额不足")
|
||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
|
||||
@@ -5,9 +5,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -15,88 +23,98 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
fallthrough
|
||||
case common.ChannelTypeZhipu:
|
||||
fallthrough
|
||||
case common.ChannelTypeAli:
|
||||
fallthrough
|
||||
case common.ChannelType360:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeXunfei:
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
if request.Model == "" {
|
||||
request.Model = "gpt-35-turbo"
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||
}
|
||||
}()
|
||||
default:
|
||||
if request.Model == "" {
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
}
|
||||
}
|
||||
baseUrl := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseUrl = channel.GetBaseURL()
|
||||
}
|
||||
requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
|
||||
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", channel.Key)
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
if response.Error.Message == "" {
|
||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
||||
if testModel == "" {
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
}
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
}
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
}
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *ChatRequest {
|
||||
testRequest := &ChatRequest{
|
||||
func buildTestRequest() *dto.GeneralOpenAIRequest {
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
}
|
||||
content, _ := json.Marshal("hi")
|
||||
testMessage := Message{
|
||||
testMessage := dto.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
@@ -113,7 +131,6 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testModel := c.Param("model")
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -122,12 +139,9 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
if testModel != "" {
|
||||
testRequest.Model = testModel
|
||||
}
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, *testRequest)
|
||||
err, _ = testChannel(channel, testModel)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
@@ -151,31 +165,6 @@ func TestChannel(c *gin.Context) {
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func notifyRootUser(subject string, content string) {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
@@ -191,7 +180,6 @@ func testAllChannels(notify bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
@@ -200,7 +188,7 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
err, openaiErr := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
@@ -217,11 +205,11 @@ func testAllChannels(notify bool) error {
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
|
||||
@@ -54,8 +54,9 @@ func FixChannelsAbilities(c *gin.Context) {
|
||||
func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
modelKeyword := c.Query("model")
|
||||
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
channels, err := model.SearchChannels(keyword, group)
|
||||
channels, err := model.SearchChannels(keyword, group, modelKeyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -20,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -38,16 +42,23 @@ func GetAllLogs(c *gin.Context) {
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -10,143 +10,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*func UpdateMidjourneyTask() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
imageModel := "midjourney"
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
for _, task := range tasks {
|
||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
||||
task.Status = "FAILURE"
|
||||
task.Progress = "100%"
|
||||
err := task.Update()
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
||||
|
||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
var responseItem Midjourney
|
||||
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
|
||||
err = json.Unmarshal(responseBody, &responseItem)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
|
||||
var responseWithoutStatus MidjourneyWithoutStatus
|
||||
var responseStatus MidjourneyStatus
|
||||
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
|
||||
err2 := json.Unmarshal(responseBody, &responseStatus)
|
||||
if err1 == nil && err2 == nil {
|
||||
jsonData, err3 := json.Marshal(responseWithoutStatus)
|
||||
if err3 != nil {
|
||||
log.Printf("UpdateMidjourneyTask error1: %v", err3)
|
||||
continue
|
||||
}
|
||||
err4 := json.Unmarshal(jsonData, &responseStatus)
|
||||
if err4 != nil {
|
||||
log.Printf("UpdateMidjourneyTask error2: %v", err4)
|
||||
continue
|
||||
}
|
||||
responseItem.Status = strconv.Itoa(responseStatus.Status)
|
||||
} else {
|
||||
log.Printf("UpdateMidjourneyTask error3: %v", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
log.Printf("UpdateMidjourneyTask error4: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
log.Println("error update user quota cache: " + err.Error())
|
||||
} else {
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
ratio := modelRatio * groupRatio
|
||||
quota := int(ratio * 1 * 1000)
|
||||
if quota != 0 {
|
||||
err := model.IncreaseUserQuota(task.UserId, quota)
|
||||
if err != nil {
|
||||
log.Println("fail to increase user quota")
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error5: %v", err)
|
||||
}
|
||||
log.Printf("UpdateMidjourneyTask success: %v", task)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
@@ -221,17 +92,21 @@ func UpdateMidjourneyTaskBulk() {
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := httpClient.Do(req)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []Midjourney
|
||||
var responseItems []dto.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
@@ -243,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
|
||||
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||
if useTime > 3600000 && task.Progress != "100%" {
|
||||
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||
responseItem.Status = "FAILURE"
|
||||
}
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
@@ -257,7 +138,16 @@ func UpdateMidjourneyTaskBulk() {
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
if responseItem.Properties != nil {
|
||||
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||
task.Properties = string(propertiesStr)
|
||||
}
|
||||
if responseItem.Buttons != nil {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
@@ -284,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -5,21 +5,41 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestStatus(c *gin.Context) {
|
||||
err := model.PingDB()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"success": false,
|
||||
"message": "数据库连接失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
@@ -27,6 +47,7 @@ func GetStatus(c *gin.Context) {
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": common.ServerAddress,
|
||||
"price": common.Price,
|
||||
"min_topup": common.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -38,6 +59,9 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -96,10 +120,20 @@ func SendEmailVerification(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的邮箱地址",
|
||||
})
|
||||
return
|
||||
}
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if strings.HasSuffix(email, "@"+domain) {
|
||||
if domainPart == domain {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
@@ -107,11 +141,22 @@ func SendEmailVerification(c *gin.Context) {
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
|
||||
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if common.EmailAliasRestrictionEnabled {
|
||||
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
|
||||
if containsSpecialSymbols {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -2,8 +2,16 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
relayconstant "one-api/relay/constant"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -53,574 +61,68 @@ func init() {
|
||||
IsBlocking: false,
|
||||
})
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
openAIModels = []OpenAIModels{
|
||||
{
|
||||
Id: "midjourney",
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
channelName := adaptor.GetChannelName()
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "Midjourney",
|
||||
Permission: permission,
|
||||
Root: "midjourney",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "dall-e-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "dall-e-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "dall-e-3",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "dall-e-3",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "whisper-1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "whisper-1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-1106",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-hd",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-hd",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-hd-1106",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-hd-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-0301",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-0301",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-0613",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-16k",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-16k",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-16k-0613",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-16k-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-1106",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-0125",
|
||||
Object: "model",
|
||||
Created: 1706232090,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-0125",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-instruct",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-instruct",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-0314",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-0314",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-0613",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-32k",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-32k",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-32k-0314",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-32k-0314",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-32k-0613",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-32k-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-1106-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-1106-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-0125-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-0125-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-turbo-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-turbo-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-vision-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-vision-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-1106-vision-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-1106-vision-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-3-small",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-embedding-ada-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-3-large",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-embedding-ada-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-ada-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-embedding-ada-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-davinci-003",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-davinci-003",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-davinci-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-davinci-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-curie-001",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-curie-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-babbage-001",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-babbage-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-ada-001",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-ada-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-moderation-latest",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-moderation-latest",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-moderation-stable",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-moderation-stable",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-davinci-edit-001",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "text-davinci-edit-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "code-davinci-edit-001",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "code-davinci-edit-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "babbage-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "babbage-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "davinci-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "davinci-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-instant-1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-instant-1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "ERNIE-Bot",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot-turbo",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "ERNIE-Bot-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot-4",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "ERNIE-Bot-4",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "Embedding-V1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "Embedding-V1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "PaLM-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "PaLM-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro-vision",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro-vision",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "chatglm_turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "chatglm_pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_std",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "chatglm_std",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_lite",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "chatglm_lite",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-turbo",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-plus",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-plus",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-v1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "text-embedding-v1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "SparkDesk",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "xunfei",
|
||||
Permission: permission,
|
||||
Root: "SparkDesk",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "360GPT_S2_V9",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
Created: 1626777600,
|
||||
OwnedBy: "360",
|
||||
Permission: permission,
|
||||
Root: "360GPT_S2_V9",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "embedding-bert-512-v1",
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "360",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "moonshot",
|
||||
Permission: permission,
|
||||
Root: "embedding-bert-512-v1",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "embedding_s1_v1",
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "360",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "lingyiwanwu",
|
||||
Permission: permission,
|
||||
Root: "embedding_s1_v1",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "semantic_similarity_s1_v1",
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "360",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: "semantic_similarity_s1_v1",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "hunyuan",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "tencent",
|
||||
Permission: permission,
|
||||
Root: "hunyuan",
|
||||
Parent: nil,
|
||||
},
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range openAIModels {
|
||||
@@ -629,6 +131,29 @@ func init() {
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
models := model.GetGroupModels(user.Group)
|
||||
userOpenAiModels := make([]OpenAIModels, 0)
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": userOpenAiModels,
|
||||
})
|
||||
}
|
||||
|
||||
func ChannelListModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": openAIModels,
|
||||
@@ -640,7 +165,7 @@ func RetrieveModel(c *gin.Context) {
|
||||
if model, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, model)
|
||||
} else {
|
||||
openAIError := OpenAIError{
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
Type: "invalid_request_error",
|
||||
Param: "model",
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||
|
||||
type AIProxyLibraryRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
LibraryId string `json:"libraryId"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryError struct {
|
||||
ErrCode int `json:"errCode"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryDocument struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Answer string `json:"answer"`
|
||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||
AIProxyLibraryError
|
||||
}
|
||||
|
||||
type AIProxyLibraryStreamResponse struct {
|
||||
Content string `json:"content"`
|
||||
Finish bool `json:"finish"`
|
||||
Model string `json:"model"`
|
||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||
}
|
||||
|
||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||
query := ""
|
||||
if len(request.Messages) != 0 {
|
||||
query = string(request.Messages[len(request.Messages)-1].Content)
|
||||
}
|
||||
return &AIProxyLibraryRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Query: query,
|
||||
}
|
||||
}
|
||||
|
||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||
if len(documents) == 0 {
|
||||
return ""
|
||||
}
|
||||
content := "\n\n参考文档:\n"
|
||||
for i, document := range documents {
|
||||
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||
choice.FinishReason = &stopFinishReason
|
||||
return &ChatCompletionsStreamResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = response.Content
|
||||
return &ChatCompletionsStreamResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: response.Model,
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
var documents []AIProxyLibraryDocument
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if len(AIProxyLibraryResponse.Documents) != 0 {
|
||||
documents = AIProxyLibraryResponse.Documents
|
||||
}
|
||||
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
response := documentsAIProxyLibrary(documents)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var AIProxyLibraryResponse AIProxyLibraryResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: AIProxyLibraryResponse.Message,
|
||||
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||
Code: AIProxyLibraryResponse.ErrCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 1000000
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var response ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||
return i + 4, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "event: completion") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
responseText += claudeResponse.Completion
|
||||
response := streamResponseClaude2OpenAI(&claudeResponse)
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||
completionTokens := countTokenText(claudeResponse.Completion, model)
|
||||
usage := Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -1,650 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Midjourney struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
}
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
if midjourneyTask == nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "midjourney_task_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "http_get_image_failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
c.JSON(resp.StatusCode, gin.H{
|
||||
"error": string(responseBody),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 从Content-Type头获取MIME类型
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
// 如果无法确定内容类型,则默认为jpeg
|
||||
contentType = "image/jpeg"
|
||||
}
|
||||
// 设置响应的内容类型
|
||||
c.Writer.Header().Set("Content-Type", contentType)
|
||||
// 将图片流式传输到响应体
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
log.Println("Failed to stream image:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
var midjRequest Midjourney
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||
if midjourneyTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "midjourney_task_not_found",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask.Progress = midjRequest.Progress
|
||||
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||||
midjourneyTask.State = midjRequest.State
|
||||
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||||
midjourneyTask.StartTime = midjRequest.StartTime
|
||||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||
midjourneyTask.Status = midjRequest.Status
|
||||
midjourneyTask.FailReason = midjRequest.FailReason
|
||||
err = midjourneyTask.Update()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "update_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
midjourneyTask.State = originTask.State
|
||||
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" {
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
}
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
return
|
||||
}
|
||||
|
||||
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
var err error
|
||||
var respBody []byte
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyTaskFetch:
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
respBody, err = json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
case RelayModeMidjourneyTaskFetchByCondition:
|
||||
var condition = struct {
|
||||
IDs []string `json:"ids"`
|
||||
}{}
|
||||
err = c.BindJSON(&condition)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
var tasks []Midjourney
|
||||
if len(condition.IDs) != 0 {
|
||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||
for _, originTask := range originTasks {
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
tasks = append(tasks, midjourneyTask)
|
||||
}
|
||||
}
|
||||
if tasks == nil {
|
||||
tasks = make([]Midjourney, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(tasks)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// type 1 根据 mode 价格不同
|
||||
MJSubmitActionImagine = "IMAGINE"
|
||||
MJSubmitActionVariation = "VARIATION" //变换
|
||||
MJSubmitActionBlend = "BLEND" //混图
|
||||
|
||||
MJSubmitActionReroll = "REROLL" //重新生成
|
||||
// type 2 固定价格
|
||||
MJSubmitActionDescribe = "DESCRIBE"
|
||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
||||
)
|
||||
|
||||
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
var midjRequest MidjourneyRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "prompt_is_required",
|
||||
}
|
||||
}
|
||||
midjRequest.Action = "IMAGINE"
|
||||
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = "DESCRIBE"
|
||||
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = "BLEND"
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
mjId := ""
|
||||
if relayMode == RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
} else if midjRequest.Action == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_is_required",
|
||||
}
|
||||
}
|
||||
params := convertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_parse_failed",
|
||||
}
|
||||
}
|
||||
mjId = params.ID
|
||||
midjRequest.Action = params.Action
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
} else if originTask.Action == "UPSCALE" {
|
||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "upscale_task_can_not_be_change",
|
||||
}
|
||||
} else if originTask.Status != "SUCCESS" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_status_is_not_success",
|
||||
}
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "channel_not_found",
|
||||
}
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_model_mapping_failed",
|
||||
}
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
log.Printf("fullRequestURL: %s", fullRequestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "marshal_text_request_failed",
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(mjAction, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "create_request_failed",
|
||||
}
|
||||
}
|
||||
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
//mjToken := ""
|
||||
//if c.Request.Header.Get("Authorization") != "" {
|
||||
// mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
|
||||
//}
|
||||
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
|
||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||
// print request header
|
||||
log.Printf("request header: %s", req.Header)
|
||||
log.Printf("request body: %s", midjRequest.Prompt)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
var midjResponse MidjourneyResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
//if consumeQuota {
|
||||
//
|
||||
//}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "read_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
log.Printf("midjResponse: %v", midjResponse)
|
||||
if resp.StatusCode != 200 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||
//1-提交成功
|
||||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||
// 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
|
||||
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
|
||||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||
// other: 提交错误,description为错误描述
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: midjRequest.Action,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: midjRequest.Prompt,
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
StartTime: 0,
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
|
||||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||||
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||||
midjourneyTask.FailReason = midjResponse.Description
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
|
||||
// 将 properties 转换为一个 map
|
||||
properties, ok := midjResponse.Properties.(map[string]interface{})
|
||||
if ok {
|
||||
imageUrl, ok1 := properties["imageUrl"].(string)
|
||||
status, ok2 := properties["status"].(string)
|
||||
if ok1 && ok2 {
|
||||
midjourneyTask.ImageUrl = imageUrl
|
||||
midjourneyTask.Status = status
|
||||
if status == "SUCCESS" {
|
||||
midjourneyTask.Progress = "100%"
|
||||
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||||
midjResponse.Code = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "insert_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type taskChangeParams struct {
|
||||
ID string
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &taskChangeParams{}
|
||||
changeParams.ID = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
@@ -1,752 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeClaude
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
var impatientHTTPClient *http.Client
|
||||
|
||||
func init() {
|
||||
if common.RelayTimeout == 0 {
|
||||
httpClient = &http.Client{}
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
impatientHTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := time.Now()
|
||||
var textRequest GeneralOpenAIRequest
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
}
|
||||
// request validation
|
||||
if textRequest.Model == "" {
|
||||
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
switch relayMode {
|
||||
case RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeChatCompletions:
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeEmbeddings:
|
||||
case RelayModeModerations:
|
||||
if textRequest.Input == "" {
|
||||
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
apiType := APITypeOpenAI
|
||||
switch channelType {
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeClaude
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
requestURL := strings.Split(requestURL, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
baseURL = c.GetString("base_url")
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
model_ := textRequest.Model
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
|
||||
}
|
||||
case APITypeClaude:
|
||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
||||
}
|
||||
case APITypeBaidu:
|
||||
switch textRequest.Model {
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
var err error
|
||||
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
||||
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
fullRequestURL += "?access_token=" + apiKey
|
||||
case APITypePaLM:
|
||||
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
case APITypeGemini:
|
||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
||||
if baseURL != "" {
|
||||
requestBaseURL = baseURL
|
||||
}
|
||||
version := "v1beta"
|
||||
if c.GetString("api_version") != "" {
|
||||
version = c.GetString("api_version")
|
||||
}
|
||||
action := "generateContent"
|
||||
if textRequest.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
//log.Println(fullRequestURL)
|
||||
|
||||
case APITypeZhipu:
|
||||
method := "invoke"
|
||||
if textRequest.Stream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
if relayMode == RelayModeEmbeddings {
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
}
|
||||
case APITypeTencent:
|
||||
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
||||
case APITypeAIProxyLibrary:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||
}
|
||||
var promptTokens int
|
||||
var completionTokens int
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
case RelayModeCompletions:
|
||||
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case RelayModeModerations:
|
||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
||||
}
|
||||
modelPrice := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
if modelPrice == -1 {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio = common.GetModelRatio(textRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
|
||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !tokenUnlimited {
|
||||
// 非无限令牌,判断令牌额度是否充足
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
if tokenQuota > 100*preConsumedQuota {
|
||||
// 令牌额度充足,信任令牌
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
|
||||
}
|
||||
} else {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
switch apiType {
|
||||
case APITypeClaude:
|
||||
claudeRequest := requestOpenAI2Claude(textRequest)
|
||||
jsonStr, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeBaidu:
|
||||
var jsonData []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
||||
jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||
jsonData, err = json.Marshal(baiduRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
case APITypePaLM:
|
||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeGemini:
|
||||
geminiChatRequest := requestOpenAI2Gemini(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeZhipu:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err := json.Marshal(zhipuRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeAli:
|
||||
var jsonStr []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
||||
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
||||
default:
|
||||
aliRequest := requestOpenAI2Ali(textRequest)
|
||||
jsonStr, err = json.Marshal(aliRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeTencent:
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
||||
}
|
||||
tencentRequest := requestOpenAI2Tencent(textRequest)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
jsonStr, err := json.Marshal(tencentRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
sign := getTencentSign(*tencentRequest, secretKey)
|
||||
c.Request.Header.Set("Authorization", sign)
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeAIProxyLibrary:
|
||||
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
var resp *http.Response
|
||||
isStream := textRequest.Stream
|
||||
|
||||
if apiType != APITypeXunfei { // cause xunfei use websocket
|
||||
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// 设置GetBody函数,该函数返回一个新的io.ReadCloser,该io.ReadCloser返回与原始请求体相同的数据
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(requestBody), nil
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", apiKey)
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
if c.Request.Header.Get("OpenAI-Organization") != "" {
|
||||
req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
|
||||
}
|
||||
if channelType == common.ChannelTypeOpenRouter {
|
||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
req.Header.Set("X-Title", "One API")
|
||||
}
|
||||
}
|
||||
case APITypeClaude:
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
case APITypeZhipu:
|
||||
token := getZhipuToken(apiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
case APITypeAli:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if textRequest.Stream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
case APITypeTencent:
|
||||
req.Header.Set("Authorization", apiKey)
|
||||
case APITypeGemini:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
if apiType != APITypeGemini {
|
||||
// 设置公共头部...
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
|
||||
resp, err = httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return relayErrorHandler(resp)
|
||||
}
|
||||
}
|
||||
|
||||
var textResponse TextResponse
|
||||
tokenName := c.GetString("token_name")
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
|
||||
quota := 0
|
||||
if modelPrice == -1 {
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(有疑问请联系管理员)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota))
|
||||
} else {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
|
||||
logModel := textRequest.Model
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
}
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
//}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if isStream {
|
||||
err, responseText := openaiStreamHandler(c, resp, relayMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeClaude:
|
||||
if isStream {
|
||||
err, responseText := claudeStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeBaidu:
|
||||
if isStream {
|
||||
err, usage := baiduStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypePaLM:
|
||||
if textRequest.Stream { // PaLM2 API does not support stream
|
||||
err, responseText := palmStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeGemini:
|
||||
if textRequest.Stream {
|
||||
err, responseText := geminiChatStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeZhipu:
|
||||
if isStream {
|
||||
err, usage := zhipuStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
// zhipu's API does not return prompt tokens & completion tokens
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
} else {
|
||||
err, usage := zhipuHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
// zhipu's API does not return prompt tokens & completion tokens
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
}
|
||||
case APITypeAli:
|
||||
if isStream {
|
||||
err, usage := aliStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = aliHandler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeXunfei:
|
||||
auth := c.Request.Header.Get("Authorization")
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
splits := strings.Split(auth, "|")
|
||||
if len(splits) != 3 {
|
||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
if isStream {
|
||||
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
case APITypeAIProxyLibrary:
|
||||
if isStream {
|
||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
err, usage := aiProxyLibraryHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeTencent:
|
||||
if isStream {
|
||||
err, responseText := tencentStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := tencentHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@@ -1,420 +1,160 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
return contentList
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
err = relay.RelayImageHelper(c, relayMode)
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err = relay.AudioHelper(c, relayMode)
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
var arrayContent []json.RawMessage
|
||||
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
||||
for _, contentItem := range arrayContent {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
detail, ok := subObj["detail"]
|
||||
if ok {
|
||||
subObj["detail"] = detail.(string)
|
||||
} else {
|
||||
subObj["detail"] = "auto"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
Detail: subObj["detail"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
RelayModeCompletions
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type AudioRequest struct {
|
||||
Model string `json:"model"`
|
||||
Voice string `json:"voice"`
|
||||
Input string `json:"input"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
}
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
OpenAIError
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAITextResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int `json:"created"`
|
||||
Data []struct {
|
||||
Url string `json:"url"`
|
||||
B64Json string `json:"b64_json"`
|
||||
}
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseSimple struct {
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
Choices []struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
relayMode = RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||
relayMode = RelayModeCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
relayMode = RelayModeModerations
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
relayMode = RelayModeImagesGenerations
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||
relayMode = RelayModeEdits
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
relayMode = RelayModeAudioSpeech
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
relayMode = RelayModeAudioTranscription
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
retryTimes := common.RetryTimes
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
channelId := c.GetInt("channel_id")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
openaiErr := relayHandler(c, relayMode)
|
||||
retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
||||
if openaiErr != nil {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
} else {
|
||||
retryTimes = 0
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case RelayModeImagesGenerations:
|
||||
err = relayImageHelper(c, relayMode)
|
||||
case RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case RelayModeAudioTranscription:
|
||||
err = relayAudioHelper(c, relayMode)
|
||||
default:
|
||||
err = relayTextHelper(c, relayMode)
|
||||
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
openaiErr = relayHandler(c, relayMode)
|
||||
if openaiErr != nil {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
|
||||
} else {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelName := c.GetString("channel_name")
|
||||
disableChannel(channelId, channelName, err.Message)
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
return false
|
||||
}
|
||||
if _, ok := c.Get("specific_channel_id"); ok {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||
channelName := c.GetString("channel_name")
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
|
||||
relayMode = RelayModeMidjourneyImagine
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
|
||||
relayMode = RelayModeMidjourneyBlend
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
|
||||
relayMode = RelayModeMidjourneyDescribe
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
|
||||
relayMode = RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
|
||||
relayMode = RelayModeMidjourneyTaskFetch
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
|
||||
relayMode = RelayModeMidjourneyTaskFetchByCondition
|
||||
}
|
||||
|
||||
var err *MidjourneyResponse
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
var err *dto.MidjourneyResponse
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyNotify:
|
||||
err = relayMidjourneyNotify(c)
|
||||
case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relayMidjourneyTask(c, relayMode)
|
||||
case relayconstant.RelayModeMidjourneyNotify:
|
||||
err = relay.RelayMidjourneyNotify(c)
|
||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||
case relayconstant.RelayModeSwapFace:
|
||||
err = relay.RelaySwapFace(c)
|
||||
default:
|
||||
err = relayMidjourneySubmit(c, relayMode)
|
||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||
}
|
||||
//err = relayMidjourneySubmit(c, relayMode)
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
}
|
||||
c.JSON(400, gin.H{
|
||||
"error": err.Description + " " + err.Result,
|
||||
})
|
||||
statusCode := http.StatusBadRequest
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
statusCode = http.StatusTooManyRequests
|
||||
}
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"type": "upstream_error",
|
||||
"code": err.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
|
||||
//if shouldDisableChannel(&err.OpenAIError) {
|
||||
// channelId := c.GetInt("channel_id")
|
||||
// channelName := c.GetString("channel_name")
|
||||
// disableChannel(channelId, channelName, err.Result)
|
||||
//};''''''''''''''''''''''''''''''''
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
}
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := OpenAIError{
|
||||
err := dto.OpenAIError{
|
||||
Message: "API not implemented",
|
||||
Type: "new_api_error",
|
||||
Param: "",
|
||||
@@ -426,7 +166,7 @@ func RelayNotImplemented(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := OpenAIError{
|
||||
err := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
|
||||
116
controller/telegram.go
Normal file
116
controller/telegram.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TelegramBind(c *gin.Context) {
|
||||
if !common.TelegramOAuthEnabled {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "管理员未开启通过 Telegram 登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
params := c.Request.URL.Query()
|
||||
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "无效的请求",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
telegramId := params["id"][0]
|
||||
if model.IsTelegramIdAlreadyTaken(telegramId) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "该 Telegram 账户已被绑定",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{Id: id.(int)}
|
||||
if err := user.FillUserById(); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
user.TelegramId = telegramId
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(302, "/setting")
|
||||
}
|
||||
|
||||
func TelegramLogin(c *gin.Context) {
|
||||
if !common.TelegramOAuthEnabled {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "管理员未开启通过 Telegram 登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
params := c.Request.URL.Query()
|
||||
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "无效的请求",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
telegramId := params["id"][0]
|
||||
user := model.User{TelegramId: telegramId}
|
||||
if err := user.FillUserByTelegramId(); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func checkTelegramAuthorization(params map[string][]string, token string) bool {
|
||||
strs := []string{}
|
||||
var hash = ""
|
||||
for k, v := range params {
|
||||
if k == "hash" {
|
||||
hash = v[0]
|
||||
continue
|
||||
}
|
||||
strs = append(strs, k+"="+v[0])
|
||||
}
|
||||
sort.Strings(strs)
|
||||
var imploded = ""
|
||||
for _, s := range strs {
|
||||
if imploded != "" {
|
||||
imploded += "\n"
|
||||
}
|
||||
imploded += s
|
||||
}
|
||||
sha256hash := sha256.New()
|
||||
io.WriteString(sha256hash, token)
|
||||
hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
|
||||
io.WriteString(hmachash, imploded)
|
||||
ss := hex.EncodeToString(hmachash.Sum(nil))
|
||||
return hash == ss
|
||||
}
|
||||
@@ -2,14 +2,17 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
epay "github.com/star-horizon/go-epay"
|
||||
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -28,7 +31,7 @@ func GetEpayClient() *epay.Client {
|
||||
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClientWithUrl(&epay.Config{
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: common.EpayId,
|
||||
Key: common.EpayKey,
|
||||
}, common.PayAddress)
|
||||
@@ -38,31 +41,46 @@ func GetEpayClient() *epay.Client {
|
||||
return withUrl
|
||||
}
|
||||
|
||||
func GetAmount(count float64, user model.User) float64 {
|
||||
func getPayMoney(amount float64, user model.User) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
amount := count * common.Price * topupGroupRatio
|
||||
return amount
|
||||
payMoney := amount * common.Price * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := common.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
return minTopup
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < 1 {
|
||||
c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10})
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
amount := GetAmount(float64(req.Amount), *user)
|
||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
var payType epay.PurchaseType
|
||||
if req.PaymentMethod == "zfb" {
|
||||
@@ -72,11 +90,10 @@ func RequestEpay(c *gin.Context) {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = epay.WechatPay
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify")
|
||||
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
payMoney := amount
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
@@ -95,9 +112,13 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / int(common.QuotaPerUnit)
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: "A" + tradeNo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
@@ -111,6 +132,33 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
@@ -122,6 +170,7 @@ func EpayNotify(c *gin.Context) {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
@@ -135,11 +184,19 @@ func EpayNotify(c *gin.Context) {
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
@@ -149,13 +206,13 @@ func EpayNotify(c *gin.Context) {
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
@@ -169,12 +226,17 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < 1 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"})
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetAmount(float64(req.Amount), *user)
|
||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -683,7 +684,7 @@ func ManageUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.HardDelete(); err != nil {
|
||||
if err := user.Delete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
@@ -724,7 +725,7 @@ func ManageUser(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
if err := user.UpdateAll(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
@@ -789,7 +790,11 @@ type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var lock = sync.Mutex{}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ version: '3.4'
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
# build: .
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
|
||||
13
dto/audio.go
Normal file
13
dto/audio.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package dto
|
||||
|
||||
type TextToSpeechRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Input string `json:"input" binding:"required"`
|
||||
Voice string `json:"voice" binding:"required"`
|
||||
Speed float64 `json:"speed"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
20
dto/dalle.go
Normal file
20
dto/dalle.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dto
|
||||
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int `json:"created"`
|
||||
Data []struct {
|
||||
Url string `json:"url"`
|
||||
B64Json string `json:"b64_json"`
|
||||
}
|
||||
}
|
||||
55
dto/error.go
Normal file
55
dto/error.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package dto
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
}
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
LocalError bool
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
95
dto/midjourney.go
Normal file
95
dto/midjourney.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package dto
|
||||
|
||||
//type SimpleMjRequest struct {
|
||||
// Prompt string `json:"prompt"`
|
||||
// CustomId string `json:"customId"`
|
||||
// Action string `json:"action"`
|
||||
// Content string `json:"content"`
|
||||
//}
|
||||
|
||||
type SwapFaceRequest struct {
|
||||
SourceBase64 string `json:"sourceBase64"`
|
||||
TargetBase64 string `json:"targetBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type MidjourneyResponseWithStatusCode struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Response MidjourneyResponse
|
||||
}
|
||||
|
||||
type MidjourneyDto struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
Buttons any `json:"buttons"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
type ActionButton struct {
|
||||
CustomId any `json:"customId"`
|
||||
Emoji any `json:"emoji"`
|
||||
Label any `json:"label"`
|
||||
Type any `json:"type"`
|
||||
Style any `json:"style"`
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
FinalPrompt string `json:"finalPrompt"`
|
||||
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||
}
|
||||
6
dto/sensitive.go
Normal file
6
dto/sensitive.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package dto
|
||||
|
||||
type SensitiveResponse struct {
|
||||
SensitiveWords []string `json:"sensitive_words"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
141
dto/text_request.go
Normal file
141
dto/text_request.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
Url string `json:"url"`
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return stringContent
|
||||
}
|
||||
return string(m.Content)
|
||||
}
|
||||
|
||||
func (m Message) IsStringContent() bool {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
return contentList
|
||||
}
|
||||
var arrayContent []json.RawMessage
|
||||
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
||||
for _, contentItem := range arrayContent {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
detail, ok := subObj["detail"]
|
||||
if ok {
|
||||
subObj["detail"] = detail.(string)
|
||||
} else {
|
||||
subObj["detail"] = "auto"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
Detail: subObj["detail"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
89
dto/text_response.go
Normal file
89
dto/text_response.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dto
|
||||
|
||||
type TextResponseWithError struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type OpenAITextResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseSimple struct {
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
Choices []struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
24
go.mod
24
go.mod
@@ -4,22 +4,23 @@ module one-api
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/chai2010/webp v1.1.1
|
||||
github.com/Calcium-Ion/go-epay v0.0.2
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
github.com/gin-contrib/static v0.0.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-playground/validator/v10 v10.16.0
|
||||
github.com/go-playground/validator/v10 v10.19.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.6
|
||||
github.com/samber/lo v1.38.1
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/image v0.15.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -27,12 +28,13 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -50,7 +52,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
@@ -63,10 +65,10 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
57
go.sum
57
go.sum
@@ -1,10 +1,16 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.1 h1:cRCvwNTkPmmLM5od0p4w0cTcYcAPaAVLYr41ujseDcc=
|
||||
github.com/Calcium-Ion/go-epay v0.0.1/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
|
||||
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
|
||||
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
@@ -17,8 +23,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
||||
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
|
||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
||||
@@ -47,10 +53,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
|
||||
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
|
||||
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
@@ -79,8 +83,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
||||
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
@@ -106,12 +108,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
@@ -141,6 +141,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
|
||||
@@ -155,9 +157,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
@@ -176,17 +177,21 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -197,16 +202,12 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
7
main.go
7
main.go
@@ -13,16 +13,17 @@ import (
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
//go:embed web/build
|
||||
//go:embed web/dist
|
||||
var buildFS embed.FS
|
||||
|
||||
//go:embed web/build/index.html
|
||||
//go:embed web/dist/index.html
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
@@ -105,7 +106,7 @@ func main() {
|
||||
common.SysLog("pprof enabled")
|
||||
}
|
||||
|
||||
controller.InitTokenEncoders()
|
||||
service.InitTokenEncoders()
|
||||
|
||||
// Initialize HTTP server
|
||||
server := gin.New()
|
||||
|
||||
14
makefile
Normal file
14
makefile
Normal file
@@ -0,0 +1,14 @@
|
||||
FRONTEND_DIR = ./web
|
||||
BACKEND_DIR = .
|
||||
|
||||
.PHONY: all build-frontend start-backend
|
||||
|
||||
all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
@cd $(BACKEND_DIR) && go run main.go &
|
||||
@@ -100,16 +100,16 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
@@ -125,17 +125,11 @@ func TokenAuth() func(c *gin.Context) {
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
consumeQuota := true
|
||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
||||
consumeQuota = false
|
||||
}
|
||||
c.Set("consume_quota", consumeQuota)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -19,62 +23,27 @@ func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("channelId")
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||
// Midjourney
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "midjourney"
|
||||
}
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求: "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
if modelLimitEnable {
|
||||
@@ -87,58 +56,137 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
if tokenModelLimit != nil {
|
||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// token model limit is empty, all models are not allowed
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
var err error
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return nil, false, fmt.Errorf(mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("original_model", modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.Abort()
|
||||
common.LogError(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": description,
|
||||
"type": "new_api_error",
|
||||
"code": code,
|
||||
})
|
||||
c.Abort()
|
||||
common.LogError(c.Request.Context(), description)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
@@ -27,8 +28,7 @@ func GetGroupModels(group string) []string {
|
||||
return models
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
@@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
if err != nil {
|
||||
// 处理错误
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 确定要使用的优先级
|
||||
var priorityToUse int
|
||||
if retry >= len(priorities) {
|
||||
// 如果重试次数大于优先级数,则使用最小的优先级
|
||||
priorityToUse = priorities[len(priorities)-1]
|
||||
} else {
|
||||
priorityToUse = priorities[retry]
|
||||
}
|
||||
return priorityToUse, nil
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
||||
}
|
||||
}
|
||||
|
||||
return channelQuery
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
|
||||
var err error = nil
|
||||
channelQuery := getChannelQuery(group, model, retry)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
@@ -52,21 +98,16 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
// Randomly choose one
|
||||
weightSum := uint(0)
|
||||
for _, ability_ := range abilities {
|
||||
weightSum += ability_.Weight
|
||||
weightSum += ability_.Weight + 10
|
||||
}
|
||||
if weightSum == 0 {
|
||||
// All weight is 0, randomly choose one
|
||||
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
||||
} else {
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight)
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight) + 10
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -147,7 +188,12 @@ func FixAbility() (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
var channels []Channel
|
||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -265,14 +265,14 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
@@ -280,35 +280,44 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
endIdx := len(channels)
|
||||
// choose by priority
|
||||
firstChannel := channels[0]
|
||||
if firstChannel.GetPriority() > 0 {
|
||||
for i := range channels {
|
||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||
endIdx = i
|
||||
break
|
||||
}
|
||||
|
||||
uniquePriorities := make(map[int]bool)
|
||||
for _, channel := range channels {
|
||||
uniquePriorities[int(channel.GetPriority())] = true
|
||||
}
|
||||
var sortedUniquePriorities []int
|
||||
for priority := range uniquePriorities {
|
||||
sortedUniquePriorities = append(sortedUniquePriorities, priority)
|
||||
}
|
||||
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
|
||||
|
||||
if retry >= len(uniquePriorities) {
|
||||
retry = len(uniquePriorities) - 1
|
||||
}
|
||||
targetPriority := int64(sortedUniquePriorities[retry])
|
||||
|
||||
// get the priority for the given retry number
|
||||
var targetChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if channel.GetPriority() == targetPriority {
|
||||
targetChannels = append(targetChannels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
// 平滑系数
|
||||
smoothingFactor := 10
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range channels[:endIdx] {
|
||||
totalWeight += channel.GetWeight()
|
||||
for _, channel := range targetChannels {
|
||||
totalWeight += channel.GetWeight() + smoothingFactor
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
// If all weights are 0, select a channel randomly
|
||||
return channels[rand.Intn(endIdx)], nil
|
||||
}
|
||||
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range channels[:endIdx] {
|
||||
randomWeight -= channel.GetWeight()
|
||||
if randomWeight <= 0 {
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= channel.GetWeight() + smoothingFactor
|
||||
if randomWeight < 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
type Channel struct {
|
||||
Id int `json:"id"`
|
||||
Type int `json:"type" gorm:"default:0"`
|
||||
Key string `json:"key" gorm:"not null;index"`
|
||||
Key string `json:"key" gorm:"not null"`
|
||||
OpenAIOrganization *string `json:"openai_organization"`
|
||||
TestModel *string `json:"test_model"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Weight *uint `json:"weight" gorm:"default:0"`
|
||||
@@ -43,21 +44,39 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string, group string) (channels []*Channel, err error) {
|
||||
func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
keyCol := "`key`"
|
||||
groupCol := "`group`"
|
||||
modelsCol := "`models`"
|
||||
|
||||
// 如果是 PostgreSQL,使用双引号
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
groupCol = `"group"`
|
||||
modelsCol = `"models"`
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
var args []interface{}
|
||||
if group != "" {
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
err = DB.Omit("key").Where("(id = ? or name LIKE ? or "+keyCol+" = ?) and "+groupCol+" LIKE ?", common.String2Int(keyword), keyword+"%", keyword, "%"+group+"%").Find(&channels).Error
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
|
||||
} else {
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
|
||||
}
|
||||
return channels, err
|
||||
|
||||
// 执行查询
|
||||
err := baseQuery.Where(whereClause, args...).Find(&channels).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
|
||||
@@ -85,7 +85,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
common.SafeGoroutine(func() {
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -60,6 +62,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
common.UsingMySQL = true
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -90,6 +93,12 @@ func InitDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
||||
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
||||
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
err = db.AutoMigrate(&Channel{})
|
||||
if err != nil {
|
||||
@@ -148,3 +157,33 @@ func CloseDB() error {
|
||||
err = sqlDB.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
lastPingTime time.Time
|
||||
pingMutex sync.Mutex
|
||||
)
|
||||
|
||||
func PingDB() error {
|
||||
pingMutex.Lock()
|
||||
defer pingMutex.Unlock()
|
||||
|
||||
if time.Since(lastPingTime) < time.Second*10 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
log.Printf("Error getting sql.DB from GORM: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = sqlDB.Ping()
|
||||
if err != nil {
|
||||
log.Printf("Error pinging DB: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
lastPingTime = time.Now()
|
||||
common.SysLog("Database pinged successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,21 +4,23 @@ type Midjourney struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
Quota int `json:"quota"`
|
||||
Buttons string `json:"buttons"`
|
||||
Properties string `json:"properties"`
|
||||
}
|
||||
|
||||
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -30,6 +31,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
@@ -42,12 +44,14 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
|
||||
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
|
||||
common.OptionMap["SMTPServer"] = ""
|
||||
common.OptionMap["SMTPFrom"] = ""
|
||||
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
|
||||
common.OptionMap["SMTPAccount"] = ""
|
||||
common.OptionMap["SMTPToken"] = ""
|
||||
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
|
||||
common.OptionMap["Notice"] = ""
|
||||
common.OptionMap["About"] = ""
|
||||
common.OptionMap["HomePageContent"] = ""
|
||||
@@ -56,12 +60,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
common.OptionMap["TelegramBotName"] = ""
|
||||
common.OptionMap["WeChatServerAddress"] = ""
|
||||
common.OptionMap["WeChatServerToken"] = ""
|
||||
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
|
||||
@@ -82,6 +90,14 @@ func InitOptionMap() {
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -138,7 +154,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ImageDownloadPermission = intValue
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(key, "Enabled") {
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
@@ -151,12 +167,16 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.GitHubOAuthEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
common.WeChatAuthEnabled = boolValue
|
||||
case "TelegramOAuthEnabled":
|
||||
common.TelegramOAuthEnabled = boolValue
|
||||
case "TurnstileCheckEnabled":
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "EmailAliasRestrictionEnabled":
|
||||
common.EmailAliasRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
common.AutomaticDisableChannelEnabled = boolValue
|
||||
case "AutomaticEnableChannelEnabled":
|
||||
@@ -171,6 +191,20 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.DrawingEnabled = boolValue
|
||||
case "DataExportEnabled":
|
||||
common.DataExportEnabled = boolValue
|
||||
case "DefaultCollapseSidebar":
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
constant.CheckSensitiveOnPromptEnabled = boolValue
|
||||
//case "CheckSensitiveOnCompletionEnabled":
|
||||
// constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
case "StopOnSensitiveEnabled":
|
||||
constant.StopOnSensitiveEnabled = boolValue
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -191,12 +225,16 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ServerAddress = value
|
||||
case "PayAddress":
|
||||
common.PayAddress = value
|
||||
case "CustomCallbackAddress":
|
||||
common.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
common.EpayId = value
|
||||
case "EpayKey":
|
||||
common.EpayKey = value
|
||||
case "Price":
|
||||
common.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
common.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -215,6 +253,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.WeChatServerToken = value
|
||||
case "WeChatAccountQRCodeImageURL":
|
||||
common.WeChatAccountQRCodeImageURL = value
|
||||
case "TelegramBotToken":
|
||||
common.TelegramBotToken = value
|
||||
case "TelegramBotName":
|
||||
common.TelegramBotName = value
|
||||
case "TurnstileSiteKey":
|
||||
common.TurnstileSiteKey = value
|
||||
case "TurnstileSecretKey":
|
||||
@@ -251,6 +293,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
||||
case "QuotaPerUnit":
|
||||
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||
case "SensitiveWords":
|
||||
constant.SensitiveWordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,16 +8,17 @@ import (
|
||||
)
|
||||
|
||||
type Redemption struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Quota int `json:"quota" gorm:"default:100"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Quota int `json:"quota" gorm:"default:100"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
|
||||
@@ -55,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
common.RandomSleep()
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -10,19 +10,20 @@ import (
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
@@ -47,7 +48,9 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
|
||||
keyPrefix := key[:3]
|
||||
keySuffix := key[len(key)-3:]
|
||||
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
|
||||
} else if token.Status == common.TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
@@ -73,7 +76,9 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
|
||||
keyPrefix := key[:3]
|
||||
keySuffix := key[len(key)-3:]
|
||||
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ type QuotaData struct {
|
||||
Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
|
||||
TokenUsed int `json:"token_used" gorm:"default:0"`
|
||||
Count int `json:"count" gorm:"default:0"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
}
|
||||
@@ -38,7 +39,7 @@ func UpdateQuotaData() {
|
||||
var CacheQuotaData = make(map[string]*QuotaData)
|
||||
var CacheQuotaDataLock = sync.Mutex{}
|
||||
|
||||
func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64) {
|
||||
func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
|
||||
key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
|
||||
quotaData, ok := CacheQuotaData[key]
|
||||
if ok {
|
||||
@@ -52,18 +53,19 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int,
|
||||
CreatedAt: createdAt,
|
||||
Count: 1,
|
||||
Quota: quota,
|
||||
TokenUsed: tokenUsed,
|
||||
}
|
||||
}
|
||||
CacheQuotaData[key] = quotaData
|
||||
}
|
||||
|
||||
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64) {
|
||||
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
|
||||
// 只精确到小时
|
||||
createdAt = createdAt - (createdAt % 3600)
|
||||
|
||||
CacheQuotaDataLock.Lock()
|
||||
defer CacheQuotaDataLock.Unlock()
|
||||
logQuotaDataCache(userId, username, modelName, quota, createdAt)
|
||||
logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed)
|
||||
}
|
||||
|
||||
func SaveQuotaDataCache() {
|
||||
|
||||
@@ -3,10 +3,12 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
@@ -21,6 +23,7 @@ type User struct {
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
@@ -70,8 +73,26 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
return users, err
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string) (users []*User, err error) {
|
||||
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
func SearchUsers(keyword string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID搜索用户
|
||||
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
// 如果依据ID找到用户或者发生错误,返回结果或错误
|
||||
return users, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索
|
||||
err = DB.Unscoped().Omit("password").
|
||||
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
|
||||
Find(&users).Error
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
@@ -208,6 +229,27 @@ func (user *User) Update(updatePassword bool) error {
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) UpdateAll(updatePassword bool) error {
|
||||
var err error
|
||||
if updatePassword {
|
||||
user.Password, err = common.Password2Hash(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newUser := *user
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Select("*").Updates(newUser).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
}
|
||||
return err
|
||||
@@ -286,6 +328,17 @@ func (user *User) FillUserByUsername() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByTelegramId() error {
|
||||
if user.TelegramId == "" {
|
||||
return errors.New("Telegram id 为空!")
|
||||
}
|
||||
err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("该 Telegram 账户未绑定")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsEmailAlreadyTaken(email string) bool {
|
||||
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -302,6 +355,10 @@ func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func ResetUserPasswordByEmail(email string, password string) error {
|
||||
if email == "" || password == "" {
|
||||
return errors.New("邮箱地址或密码为空!")
|
||||
|
||||
21
relay/channel/adapter.go
Normal file
21
relay/channel/adapter.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
8
relay/channel/ai360/constants.go
Normal file
8
relay/channel/ai360/constants.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package ai360
|
||||
|
||||
var ModelList = []string{
|
||||
"360GPT_S2_V9",
|
||||
"embedding-bert-512-v1",
|
||||
"embedding_s1_v1",
|
||||
"semantic_similarity_s1_v1",
|
||||
}
|
||||
80
relay/channel/ali/adaptor.go
Normal file
80
relay/channel/ali/adaptor.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
|
||||
if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
}
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.IsStream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Ali(*request)
|
||||
return baiduRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = aliStreamHandler(c, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = aliHandler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
8
relay/channel/ali/constants.go
Normal file
8
relay/channel/ali/constants.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package ali
|
||||
|
||||
var ModelList = []string{
|
||||
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
|
||||
"text-embedding-v1",
|
||||
}
|
||||
|
||||
var ChannelName = "ali"
|
||||
72
relay/channel/ali/dto.go
Normal file
72
relay/channel/ali/dto.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package ali
|
||||
|
||||
type AliMessage struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
//History []AliMessage `json:"history,omitempty"`
|
||||
Messages []AliMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input,omitempty"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -7,119 +7,46 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
const EnableSearchModelSuffix = "-internet"
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
//prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: string(message.Content),
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = string(message.Content)
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: string(message.Content),
|
||||
Bot: string(request.Messages[i+1].Content),
|
||||
})
|
||||
i++
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: strings.ToLower(message.Role),
|
||||
})
|
||||
}
|
||||
enableSearch := false
|
||||
aliModel := request.Model
|
||||
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
|
||||
enableSearch = true
|
||||
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
//Prompt: prompt,
|
||||
Messages: messages,
|
||||
},
|
||||
Parameters: AliParameters{
|
||||
IncrementalOutput: request.Stream,
|
||||
Seed: uint64(request.Seed),
|
||||
EnableSearch: enableSearch,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
// TopK: 50,
|
||||
// //Seed: 0,
|
||||
// //EnableSearch: false,
|
||||
//},
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
@@ -130,21 +57,21 @@ func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingReque
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
@@ -157,7 +84,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -165,16 +92,16 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
@@ -183,22 +110,22 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Output.Text)
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Usage: Usage{
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
@@ -207,25 +134,25 @@ func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -255,7 +182,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
@@ -288,28 +215,28 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
@@ -321,7 +248,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
52
relay/channel/api_request.go
Normal file
52
relay/channel/api_request.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
return resp, nil
|
||||
}
|
||||
92
relay/channel/baidu/adaptor.go
Normal file
92
relay/channel/baidu/adaptor.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.UpstreamModelName {
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "ERNIE-Bot-8K":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
|
||||
case "ERNIE-Bot":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Speed":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||
}
|
||||
var accessToken string
|
||||
var err error
|
||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fullRequestURL += "?access_token=" + accessToken
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
|
||||
return baiduEmbeddingRequest, nil
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Baidu(*request)
|
||||
return baiduRequest, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/baidu/constants.go
Normal file
12
relay/channel/baidu/constants.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package baidu
|
||||
|
||||
var ModelList = []string{
|
||||
"ERNIE-Bot-4",
|
||||
"ERNIE-Bot-8K",
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Speed",
|
||||
"ERNIE-Bot-turbo",
|
||||
"Embedding-V1",
|
||||
}
|
||||
|
||||
var ChannelName = "baidu"
|
||||
71
relay/channel/baidu/dto.go
Normal file
71
relay/channel/baidu/dto.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"one-api/dto"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage dto.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage dto.Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
type BaiduTokenResponse struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -16,91 +19,15 @@ import (
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
|
||||
type BaiduTokenResponse struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type BaiduMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type BaiduError struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: string(message.Content),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: string(message.Content),
|
||||
})
|
||||
}
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
@@ -108,57 +35,57 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(response.Result)
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
Usage: response.Usage,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
@@ -167,8 +94,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -195,7 +122,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -225,28 +152,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
@@ -258,7 +185,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -266,23 +193,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
@@ -294,7 +221,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -337,7 +264,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := impatientHTTPClient.Do(req)
|
||||
res, err := service.GetImpatientHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
81
relay/channel/claude/adaptor.go
Normal file
81
relay/channel/claude/adaptor.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if a.RequestMode == RequestModeMessage {
|
||||
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
|
||||
} else {
|
||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-api-key", info.ApiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return requestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return requestOpenAI2ClaudeMessage(*request)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
||||
} else {
|
||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
13
relay/channel/claude/constants.go
Normal file
13
relay/channel/claude/constants.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package claude
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-instant-1.2",
|
||||
"claude-2",
|
||||
"claude-2.0",
|
||||
"claude-2.1",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
68
relay/channel/claude/dto.go
Normal file
68
relay/channel/claude/dto.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package claude
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Content []ClaudeMediaMessage `json:"content"`
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
Index int `json:"index"` // stream only
|
||||
Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
}
|
||||
|
||||
//type ClaudeResponseChoice struct {
|
||||
// Index int `json:"index"`
|
||||
// Type string `json:"type"`
|
||||
//}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
344
relay/channel/claude/relay-claude.go
Normal file
344
relay/channel/claude/relay-claude.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 1000000
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" {
|
||||
claudeRequest.System = message.StringContent()
|
||||
} else {
|
||||
claudeMessage := ClaudeMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = mediaMessage.Text
|
||||
} else {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
} else {
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claudeMediaMessage.Source.MediaType = "image/" + format
|
||||
claudeMediaMessage.Source.Data = base64String
|
||||
}
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
claudeMessage.Content = claudeMediaMessages
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
var claudeUsage *ClaudeUsage
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
claudeUsage = &claudeResponse.Message.Usage
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
choice.Index = claudeResponse.Index
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
claudeUsage = &claudeResponse.Usage
|
||||
}
|
||||
}
|
||||
if claudeUsage == nil {
|
||||
claudeUsage = &ClaudeUsage{}
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
return &response, claudeUsage
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
}
|
||||
if reqMode == RequestModeCompletion {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
} else {
|
||||
fullTextResponse.Id = claudeResponse.Id
|
||||
for i, message := range claudeResponse.Content {
|
||||
content, _ := json.Marshal(message.Text)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse.Choices = choices
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage *dto.Usage
|
||||
usage = &dto.Usage{}
|
||||
responseText := ""
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
if requestMode == RequestModeCompletion {
|
||||
responseText += claudeResponse.Completion
|
||||
responseId = response.Id
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
responseId = claudeResponse.Message.Id
|
||||
modelName = claudeResponse.Message.Model
|
||||
usage.PromptTokens = claudeUsage.InputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
responseText += claudeResponse.Delta.Text
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
usage.CompletionTokens = claudeUsage.OutputTokens
|
||||
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
//response.Id = responseId
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
response.Model = modelName
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||
} else {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = promptTokens + completionTokens
|
||||
} else {
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
67
relay/channel/gemini/adaptor.go
Normal file
67
relay/channel/gemini/adaptor.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
version := "v1"
|
||||
if info.ApiVersion != "" {
|
||||
version = info.ApiVersion
|
||||
}
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return CovertGemini2OpenAI(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = geminiChatStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/gemini/constant.go
Normal file
12
relay/channel/gemini/constant.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package gemini
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
|
||||
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
62
relay/channel/gemini/dto.go
Normal file
62
relay/channel/gemini/dto.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package gemini
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -7,57 +7,16 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
@@ -97,7 +56,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
Role: message.Role,
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: string(message.Content),
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -106,16 +65,16 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
|
||||
if part.Type == ContentTypeText {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == ContentTypeImageURL {
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
@@ -154,11 +113,6 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
@@ -169,41 +123,25 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
FinishReason: relaycommon.StopFinishReason,
|
||||
}
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
choice.Message.Content = content
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
@@ -211,18 +149,18 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response ChatCompletionsStreamResponse
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
@@ -252,7 +190,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -264,14 +202,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
@@ -287,28 +225,28 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
@@ -318,8 +256,8 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := Usage{
|
||||
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
@@ -327,7 +265,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
9
relay/channel/lingyiwanwu/constrants.go
Normal file
9
relay/channel/lingyiwanwu/constrants.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package lingyiwanwu
|
||||
|
||||
// https://platform.lingyiwanwu.com/docs
|
||||
|
||||
var ModelList = []string{
|
||||
"yi-34b-chat-0205",
|
||||
"yi-34b-chat-200k",
|
||||
"yi-vl-plus",
|
||||
}
|
||||
7
relay/channel/moonshot/constants.go
Normal file
7
relay/channel/moonshot/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package moonshot
|
||||
|
||||
var ModelList = []string{
|
||||
"moonshot-v1-8k",
|
||||
"moonshot-v1-32k",
|
||||
"moonshot-v1-128k",
|
||||
}
|
||||
73
relay/channel/ollama/adaptor.go
Normal file
73
relay/channel/ollama/adaptor.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embeddings", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return requestOpenAI2Embeddings(*request), nil
|
||||
default:
|
||||
return requestOpenAI2Ollama(*request), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
5
relay/channel/ollama/constants.go
Normal file
5
relay/channel/ollama/constants.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package ollama
|
||||
|
||||
var ModelList []string
|
||||
|
||||
var ChannelName = "ollama"
|
||||
26
relay/channel/ollama/dto.go
Normal file
26
relay/channel/ollama/dto.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package ollama
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
//type OllamaOptions struct {
|
||||
//}
|
||||
108
relay/channel/ollama/relay-ollama.go
Normal file
108
relay/channel/ollama/relay-ollama.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
str, ok := request.Stop.(string)
|
||||
var Stop []string
|
||||
if ok {
|
||||
Stop = []string{str}
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
return &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: request.Input,
|
||||
}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: ollamaEmbeddingResponse.Embedding,
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: model,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := json.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// Copy headers
|
||||
for k, v := range resp.Header {
|
||||
// 删除任何现有的相同头部,以防止重复添加头部
|
||||
c.Writer.Header().Del(k)
|
||||
for _, vv := range v {
|
||||
c.Writer.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
98
relay/channel/openai/adaptor.go
Normal file
98
relay/channel/openai/adaptor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
model_ := info.UpstreamModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
req.Header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
// req.Header.Set("X-Title", "One API")
|
||||
//}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
switch a.ChannelType {
|
||||
case common.ChannelType360:
|
||||
return ai360.ModelList
|
||||
case common.ChannelTypeMoonshot:
|
||||
return moonshot.ModelList
|
||||
case common.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
default:
|
||||
return ModelList
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
21
relay/channel/openai/constant.go
Normal file
21
relay/channel/openai/constant.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openai
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
var ChannelName = "openai"
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
var responseTextBuilder strings.Builder
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@@ -33,11 +37,10 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
defer close(stopChan)
|
||||
defer close(dataChan)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
var streamItems []string
|
||||
var streamItems []string // store stream items
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
@@ -54,28 +57,46 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
}
|
||||
streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
var streamResponses []ChatCompletionsStreamResponseSimple
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return // just ignore the error
|
||||
}
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponseSimple
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
case RelayModeCompletions:
|
||||
var streamResponses []CompletionsStreamResponse
|
||||
case relayconstant.RelayModeCompletions:
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return // just ignore the error
|
||||
}
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,7 +106,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
}
|
||||
common.SafeSend(stopChan, true)
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -102,30 +123,30 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String()
|
||||
}
|
||||
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var textResponse TextResponse
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var simpleResponse dto.SimpleResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
err = json.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
if simpleResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: simpleResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
@@ -140,23 +161,24 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
if simpleResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += countTokenText(string(choice.Message.Content), model)
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
textResponse.Usage = Usage{
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
return nil, &textResponse.Usage
|
||||
return nil, &simpleResponse.Usage
|
||||
}
|
||||
59
relay/channel/palm/adaptor.go
Normal file
59
relay/channel/palm/adaptor.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package palm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/palm/constants.go
Normal file
7
relay/channel/palm/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package palm
|
||||
|
||||
var ModelList = []string{
|
||||
"PaLM-2",
|
||||
}
|
||||
|
||||
var ChannelName = "google palm"
|
||||
38
relay/channel/palm/dto.go
Normal file
38
relay/channel/palm/dto.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package palm
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK uint `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []dto.Message `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -7,47 +7,15 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK uint `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []Message `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
}
|
||||
|
||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||
@@ -59,7 +27,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: string(message.Content),
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
@@ -71,15 +39,15 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
content, _ := json.Marshal(candidate.Content)
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
@@ -90,20 +58,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response ChatCompletionsStreamResponse
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
@@ -144,7 +112,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -157,28 +125,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
@@ -188,8 +156,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
||||
usage := Usage{
|
||||
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
@@ -197,7 +165,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
63
relay/channel/perplexity/adaptor.go
Normal file
63
relay/channel/perplexity/adaptor.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package perplexity
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.99
|
||||
}
|
||||
return requestOpenAI2Perplexity(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/perplexity/constants.go
Normal file
7
relay/channel/perplexity/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package perplexity
|
||||
|
||||
var ModelList = []string{
|
||||
"sonar-small-chat", "sonar-small-online", "sonar-medium-chat", "sonar-medium-online", "mistral-7b-instruct", "mixtral-8x7b-instruct",
|
||||
}
|
||||
|
||||
var ChannelName = "perplexity"
|
||||
21
relay/channel/perplexity/relay-perplexity.go
Normal file
21
relay/channel/perplexity/relay-perplexity.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package perplexity
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
return &dto.GeneralOpenAIRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
}
|
||||
}
|
||||
73
relay/channel/tencent/adaptor.go
Normal file
73
relay/channel/tencent/adaptor.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
Sign string
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", info.UpstreamModelName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tencentRequest := requestOpenAI2Tencent(*request)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
// we have to calculate the sign here
|
||||
a.Sign = getTencentSign(*tencentRequest, secretKey)
|
||||
return tencentRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
9
relay/channel/tencent/constants.go
Normal file
9
relay/channel/tencent/constants.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package tencent
|
||||
|
||||
var ModelList = []string{
|
||||
"ChatPro",
|
||||
"ChatStd",
|
||||
"hunyuan",
|
||||
}
|
||||
|
||||
var ChannelName = "tencent"
|
||||
61
relay/channel/tencent/dto.go
Normal file
61
relay/channel/tencent/dto.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package tencent
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage dto.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -12,6 +12,9 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,72 +22,14 @@ import (
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
type TencentMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TencentChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []TencentMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type TencentError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type TencentUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type TencentResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type TencentChatResponse struct {
|
||||
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage Usage `json:"usage,omitempty"` // token 数量
|
||||
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
||||
|
||||
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
|
||||
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "user",
|
||||
Content: string(message.Content),
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, TencentMessage{
|
||||
Role: "assistant",
|
||||
@@ -93,7 +38,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, TencentMessage{
|
||||
Content: string(message.Content),
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
@@ -112,17 +57,17 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: response.Usage,
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
content, _ := json.Marshal(response.Choices[0].Messages.Content)
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
@@ -133,24 +78,24 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
|
||||
response := ChatCompletionsStreamResponse{
|
||||
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "tencent-hunyuan",
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@@ -181,7 +126,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -209,28 +154,28 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWith
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var TencentResponse TencentChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
@@ -240,7 +185,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatus
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
68
relay/channel/xunfei/adaptor.go
Normal file
68
relay/channel/xunfei/adaptor.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
a.request = request
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
dummyResp.StatusCode = http.StatusOK
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
if a.request == nil {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
||||
}
|
||||
if info.IsStream {
|
||||
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
11
relay/channel/xunfei/constants.go
Normal file
11
relay/channel/xunfei/constants.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package xunfei
|
||||
|
||||
var ModelList = []string{
|
||||
"SparkDesk",
|
||||
"SparkDesk-v1.1",
|
||||
"SparkDesk-v2.1",
|
||||
"SparkDesk-v3.1",
|
||||
"SparkDesk-v3.5",
|
||||
}
|
||||
|
||||
var ChannelName = "xunfei"
|
||||
59
relay/channel/xunfei/dto.go
Normal file
59
relay/channel/xunfei/dto.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package xunfei
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text dto.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user