Compare commits

..

94 Commits

Author SHA1 Message Date
wozulong
8b55116563 update build files
Signed-off-by: wozulong <>
2024-03-20 18:31:22 +08:00
wozulong
f35e63e3f3 limit 'LINUX DO' trust level now available
Signed-off-by: wozulong <>
2024-03-20 16:54:38 +08:00
wozulong
17c409de23 update gh action
Signed-off-by: wozulong <>
2024-03-20 14:11:21 +08:00
wozulong
e4753e7411 synced with upstream
Signed-off-by: wozulong <>
2024-03-20 13:52:10 +08:00
CaIon
bec21ade9d Update README.md 2024-03-18 20:44:20 +08:00
CaIon
4c4e087060 fix: api type error 2024-03-18 19:55:19 +08:00
CaIon
beadb98a8c feat: support perplexity 2024-03-18 18:09:41 +08:00
CaIon
615d109d70 Merge remote-tracking branch 'origin/main' 2024-03-18 18:08:09 +08:00
Calcium-Ion
0da49fa446 Merge pull request #125 from wozulong/main
fix: openai_organization not working
2024-03-18 18:07:34 +08:00
CaIon
578b5f6536 feat: support perplexity 2024-03-18 18:00:19 +08:00
我秦始皇
f63ad9c03c fix: make the 'openai_organization' parameter actually work. 2024-03-18 02:19:59 -07:00
我秦始皇
0fb98e44a7 fix: make the 'openai_organization' parameter actually work. 2024-03-18 02:11:21 -07:00
CaIon
652bb4a53c Update README.md 2024-03-18 00:16:32 +08:00
CaIon
a26b9a9bff Update Midjourney.md 2024-03-18 00:13:37 +08:00
CaIon
f82aa956bd feat: midjourneys表添加索引 2024-03-17 22:56:17 +08:00
CaIon
b8c053c37f fix: chatglm top_p error (close #124) 2024-03-17 22:55:29 +08:00
CaIon
2581b37394 fix: 侧边栏聚焦问题 (close #123) 2024-03-17 16:37:19 +08:00
CaIon
f65477d054 fix: 修复日志翻页问题 (close #122) 2024-03-17 16:18:45 +08:00
CaIon
4b952b8582 feat: print midjourney-proxy request error 2024-03-16 17:25:10 +08:00
CaIon
a6ba1d01d9 feat: 添加回调未开启提示 2024-03-15 22:15:16 +08:00
CaIon
2d2fec24d0 chore: update dependency 2024-03-15 18:35:15 +08:00
CaIon
d34b55c154 chore: reformat code 2024-03-15 16:05:33 +08:00
CaIon
25ec99913b feat: Add logs page-size 2024-03-15 15:07:14 +08:00
CaIon
7f22d58574 feat: Only update weight and priority when blur 2024-03-15 14:53:33 +08:00
CaIon
dfdeadf1a5 feat: Remember and automatically set page-size (close #118) 2024-03-15 14:45:52 +08:00
paderlol
9adefa80b9 Optimizing Docker image builds (#1)
Co-authored-by: pader.zhang <pader.zhang@starlight-sms.com>
2024-03-14 23:10:05 -07:00
CaIon
6ab1b3a524 fix: fix image-seed error 2024-03-14 22:17:03 +08:00
CaIon
4917e5a92f feat: 允许开关mj回调 2024-03-14 21:21:37 +08:00
wozulong
4ce2381182 update issue template
Signed-off-by: wozulong <>
2024-03-14 20:02:22 +08:00
我秦始皇
62afc21ea5 Merge branch 'Calcium-Ion:main' into main 2024-03-14 03:56:31 -07:00
wozulong
7ddb7c586d 1. add LINUX DO oauth
2. fix oauth reg aff issue

Signed-off-by: wozulong <>
2024-03-14 18:53:54 +08:00
Calcium-Ion
f62dcbf669 Merge pull request #114 from Calcium-Ion/midjourney-proxy-plus
feat: support midjourney-proxy-plus
2024-03-14 18:49:30 +08:00
CaIon
299911d4cd fix: 修复epay可能重复回调的问题 2024-03-14 18:33:35 +08:00
CaIon
84e0544604 refactor: 修改超时时间 2024-03-14 18:19:22 +08:00
CaIon
2786a6b539 Update README.md 2024-03-14 18:10:09 +08:00
CaIon
9b5353a81a feat: support InsightFace (close #60) 2024-03-14 18:09:57 +08:00
CaIon
bc5a54df59 feat: support image-seed (close #86) 2024-03-14 18:09:56 +08:00
CaIon
d704902b70 feat: 兼容自定义变焦,完善modal操作 2024-03-14 16:42:37 +08:00
CaIon
614220a0fb feat: 超过一小时的任务自动失败 2024-03-14 15:16:36 +08:00
CaIon
98c1f66d61 feat: support Anthropic Claude 3.0 Haiku (close #116) 2024-03-14 14:28:21 +08:00
CaIon
a77fbc0fa2 fix: reroll action error 2024-03-14 00:43:32 +08:00
CaIon
44361d75e8 fix: "Inpaint" code error 2024-03-13 23:17:12 +08:00
CaIon
9b2e5c2978 refactor: remove consumeQuota 2024-03-13 22:30:10 +08:00
CaIon
d3399d68f6 fix: fix typo 2024-03-13 22:24:02 +08:00
CaIon
3d10c9f090 feat: 将操作拆分成单独的模型 2024-03-13 21:19:48 +08:00
CaIon
d5ffaf2502 feat: 操作细分 2024-03-13 18:26:16 +08:00
CaIon
2ad591411e feat: support shorten 2024-03-13 17:46:34 +08:00
CaIon
728dbed28d feat: 兼容变焦功能 2024-03-13 16:29:27 +08:00
CaIon
fd3a41bacb feat: 请求超时处理 2024-03-13 16:19:22 +08:00
CaIon
37c0c8ebdd feat: 初步兼容midjourney-proxy-plus 2024-03-13 15:37:01 +08:00
CaIon
95d8059c90 Merge remote-tracking branch 'origin/main' 2024-03-12 18:18:57 +08:00
CaIon
c012306400 feat: add parameter top_k 2024-03-12 18:18:47 +08:00
Calcium-Ion
1e1a53e4d2 Merge pull request #113 from xyspg/main
fix: layout issues
2024-03-12 17:36:02 +08:00
Kang Jiaming
30d48ea473 fix: layout issues 2024-03-12 12:49:36 +08:00
CaIon
9e5c636490 Update README.md 2024-03-12 02:39:10 +08:00
CaIon
d53d3386e9 feat: support ollama (close #112) 2024-03-12 02:36:39 +08:00
CaIon
c3a01decd8 chore: drop idx_channels_key on start 2024-03-12 00:35:12 +08:00
CaIon
c29680b301 feat: 前端美化 2024-03-11 21:08:51 +08:00
CaIon
ac7407ce9c Update README.md 2024-03-11 19:27:54 +08:00
Calcium-Ion
3ea061a820 Merge pull request #98 from h1xy/main
Update README.md
2024-03-11 19:23:42 +08:00
CaIon
d4df1960b2 fix: 修复模型映射功能失效 (close #105) 2024-03-11 18:28:51 +08:00
CaIon
97dd80541b fix: 请求阿里通义千问异常 (close #108) 2024-03-11 14:35:46 +08:00
CaIon
bf241b218f fix: 删除用户改为软删除 (close #107) 2024-03-11 14:13:45 +08:00
CaIon
7ab6c6c303 fix: 修复claude渠道流模式计费可能异常 2024-03-09 13:25:47 +08:00
CaIon
8fe8340b6e fix: fix gemini channel test 2024-03-08 21:38:51 +08:00
CaIon
26ef906c61 fix: fix claude channel test 2024-03-08 21:38:43 +08:00
CaIon
655dfe0d09 feat: update claude default model ratio 2024-03-08 21:17:32 +08:00
CaIon
f43b268520 fix: fix claude 3 request missing the 'max_token' field 2024-03-08 21:16:12 +08:00
CaIon
37113c0e96 perf: 优化其他设置显示效果 2024-03-08 20:35:15 +08:00
CaIon
3c3c53051d fix: fix gemini 2024-03-08 20:29:04 +08:00
CaIon
6a83c8ad86 fix: 隐藏无用按钮 2024-03-08 20:14:49 +08:00
CaIon
0640dd81fd fix: fix Azure channel test (close #101) 2024-03-08 20:06:40 +08:00
CaIon
cb50fcaffe Update README.md 2024-03-08 19:51:52 +08:00
Calcium-Ion
eca48268b2 Merge pull request #103 from Calcium-Ion/dev
feat: support Claude 3
2024-03-08 19:47:24 +08:00
CaIon
4a0af1ea3c feat: support Claude 3 2024-03-08 19:43:33 +08:00
CaIon
c2965eb835 feat: support claude3 not stream 2024-03-08 18:26:18 +08:00
Calcium-Ion
4186880e4c Merge pull request #96 from QuentinHsu/refactor-other-setting
refactor(OtherSetting): change UI, improve interaction
2024-03-08 16:52:45 +08:00
1808837298@qq.com
280c63e1d4 feat: 添加充值记录按钮 2024-03-07 22:41:04 +08:00
h1xy
9e0a54e943 Update README.md
增加telegram授权登录的配置步骤。
2024-03-06 23:37:32 +08:00
1808837298@qq.com
0e06be8c3e fix: fix bug 2024-03-06 19:53:31 +08:00
QuentinHsu
931d22c96f perf(OtherSetting): code logic 2024-03-06 18:26:22 +08:00
QuentinHsu
6413bf342a refactor(OtherSetting): change UI, improve interaction 2024-03-06 17:57:53 +08:00
1808837298@qq.com
626217fbd4 fix: 修复流模式错误扣费的问题 (close #95) 2024-03-06 17:41:55 +08:00
1808837298@qq.com
9de2d21e1a fix: 恢复微信公众号扫码功能 (close #7) 2024-03-06 16:53:21 +08:00
1808837298@qq.com
3ab4f145db feat: 支持部分渠道的system角色 (close #89) 2024-03-06 14:16:04 +08:00
Calcium-Ion
da73dca9a7 Merge pull request #81 from QuentinHsu/fix-display-home-content
fix: the content displayed on the homepage is incorrect
2024-03-05 23:16:48 +08:00
Calcium-Ion
415d296171 Merge branch 'main' into fix-display-home-content 2024-03-05 23:16:35 +08:00
Calcium-Ion
d5c5c30312 Merge pull request #91 from QuentinHsu/fix-model-name-not-trimmed
fix: model name not trimmed
2024-03-05 23:08:19 +08:00
1808837298@qq.com
fb39b6b30e fix: Fix the issue of 'unknown relay mode' when making an embedding request (close #93) 2024-03-05 23:04:57 +08:00
QuentinHsu
92c1ed7f1d perf: prompt when the name of the custom model input already exists 2024-03-05 12:18:30 +08:00
QuentinHsu
d160736a49 fix: model names must not contain spaces at both ends 2024-03-05 12:16:13 +08:00
1808837298@qq.com
fe7f42fc2e feat: add "/api/status/test" 2024-03-04 19:32:59 +08:00
QuentinHsu
5cb933a278 fix: restore the set of StatusContext 2024-03-04 16:39:42 +08:00
QuentinHsu
ac64fd26ad fix: the content displayed on the homepage is incorrect 2024-03-02 15:23:38 +08:00
118 changed files with 19391 additions and 7287 deletions

View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: 项目群聊
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
about: QQ 群629454374
- name: 交流社区
url: https://linux.do
about: 项目交流社区

View File

@@ -43,7 +43,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
calciumion/new-api
pengzhile/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
calciumion/new-api
pengzhile/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -1,32 +1,29 @@
FROM node:16 as builder
FROM node:16-slim as builder
WORKDIR /build
COPY web/package.json .
RUN npm install
COPY web/yarn.lock .
RUN yarn install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) yarn build
FROM golang:1.19-alpine AS builder2
RUN apk add --no-cache build-base
ENV GO111MODULE=on \
CGO_ENABLED=1 \
GOOS=linux
WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
#ADD go.mod go.sum ./
COPY . .
COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
RUN go mod tidy \
&& go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
RUN apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates 2>/dev/null || true
COPY --from=builder2 /build/one-api /
EXPOSE 3000
WORKDIR /data

View File

@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
@cd $(FRONTEND_DIR) && yarn install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) yarn build
start-backend:
@echo "Starting backend dev server..."

View File

@@ -2,290 +2,66 @@
**简介**:Midjourney Proxy API文档
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
## 模型列表
### midjourney-proxy支持
- mj_imagine (绘图)
- mj_variation (变换)
- mj_reroll (重绘)
- mj_blend (混合)
- mj_upscale (放大)
- mj_describe (图生文)
### 仅midjourney-proxy-plus支持
- mj_zoom (比例变焦)
- mj_shorten (提示词缩短)
- mj_modal (窗口提交局部重绘和自定义比例变焦必须和mj_modal一同添加)
- mj_inpaint (局部重绘提交必须和mj_modal一同添加)
- mj_custom_zoom (自定义比例变焦必须和mj_modal一同添加)
- mj_high_variation (强变换)
- mj_low_variation (弱变换)
- mj_pan (平移)
- swap_face (换脸)
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
```json
{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05
"mj_upscale": 0.05,
"swap_face": 0.05
}
```
其中mj_inpaint和mj_custom_zoom的价格设置为0是因为这两个模型需要搭配mj_modal使用所以价格由mj_modal决定。
## 渠道设置
### 对接 midjourney-proxy
1. 部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
### 对接 midjourney-proxy(plus)
1.
部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus**
,模型请参考上方模型列表
3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
2. 地址填写上游new api的地址例如http://localhost:8080
3. 密钥填写上游new api的密钥
## 任务提交
### 绘图变化
**接口地址**:`/mj/submit/change`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"action"
:
"UPSCALE",
"index"
:
1,
"notifyHook"
:
"",
"state"
:
"",
"taskId"
:
"1320098173412546"
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
| &emsp;&emsp;action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
| &emsp;&emsp;index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
| &emsp;&emsp;notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
| &emsp;&emsp;state | 自定义参数 | | false | string | |
| &emsp;&emsp;taskId | 任务ID | | true | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
### 提交Imagine任务
**接口地址**:`/mj/submit/imagine`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"base64"
:
"",
"notifyHook"
:
"",
"prompt"
:
"Cat",
"state"
:
""
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-------------------------|------|-------|-------------|-------------|
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
| &emsp;&emsp;base64 | 垫图base64 | | false | string | |
| &emsp;&emsp;notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
| &emsp;&emsp;prompt | 提示词 | | true | string | |
| &emsp;&emsp;state | 自定义参数 | | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
## 任务查询
### 指定ID获取任务
**接口地址**:`/mj/task/{id}/fetch`
**请求方式**:`GET`
**请求数据类型**:`application/x-www-form-urlencoded`
**响应数据类型**:`*/*`
**接口描述**:
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------|------|------|-------|--------|--------|
| id | 任务ID | path | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 任务 |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|----------------------------------------------------------|----------------|----------------|
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
| description | 任务描述 | string | |
| failReason | 失败原因 | string | |
| finishTime | 结束时间 | integer(int64) | integer(int64) |
| id | 任务ID | string | |
| imageUrl | 图片url | string | |
| progress | 任务进度 | string | |
| prompt | 提示词 | string | |
| promptEn | 提示词-英文 | string | |
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
| state | 自定义参数 | string | |
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
| submitTime | 提交时间 | integer(int64) | integer(int64) |
**响应示例**:
```javascript
{
"action"
:
"",
"description"
:
"",
"failReason"
:
"",
"finishTime"
:
0,
"id"
:
"",
"imageUrl"
:
"",
"progress"
:
"",
"prompt"
:
"",
"promptEn"
:
"",
"startTime"
:
0,
"state"
:
"",
"status"
:
"",
"submitTime"
:
0
}
```
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
2. 地址填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥

View File

@@ -18,7 +18,7 @@
此分叉版本的主要变更如下:
1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持[对接文档](Midjourney.md),支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
@@ -26,6 +26,11 @@
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
@@ -37,12 +42,19 @@
9. 支持渠道**加权随机**
10. 数据看板
11. 可设置令牌能调用的模型
12. 支持Telegram授权登录
12. 支持Telegram授权登录
1. 系统设置-配置登录注册-允许通过Telegram登录
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串
## 模型支持
此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*
2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
@@ -86,6 +98,12 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

View File

@@ -50,6 +50,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
@@ -82,6 +83,10 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -186,10 +191,10 @@ const (
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeAPI2D = 2
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
@@ -211,6 +216,7 @@ const (
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
)
var ChannelBaseURLs = []string{
@@ -218,7 +224,7 @@ var ChannelBaseURLs = []string{
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
@@ -238,7 +244,8 @@ var ChannelBaseURLs = []string{
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
}

View File

@@ -2,5 +2,6 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@@ -5,14 +5,14 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/chai2010/webp"
"golang.org/x/image/webp"
"image"
"io"
"net/http"
"strings"
)
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, "", err
return image.Config{}, "", "", err
}
// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, format, err := getImageConfig(reader)
return config, format, err
return config, format, base64String, err
}
func IsImageUrl(url string) (bool, error) {
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {

View File

@@ -13,7 +13,7 @@ import (
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
var DefaultModelRatio = map[string]float64{
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
@@ -61,8 +61,12 @@ var ModelRatio = map[string]float64{
"text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
@@ -91,17 +95,32 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
var ModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
var DefaultModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
}
var ModelPrice = map[string]float64{}
var ModelRatio = map[string]float64{}
func ModelPrice2JSONString() string {
if len(ModelPrice) == 0 {
ModelPrice = DefaultModelPrice
}
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
@@ -115,6 +134,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
}
func GetModelPrice(name string, printErr bool) float64 {
if len(ModelPrice) == 0 {
ModelPrice = DefaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -129,6 +151,9 @@ func GetModelPrice(name string, printErr bool) float64 {
}
func ModelRatio2JSONString() string {
if len(ModelRatio) == 0 {
ModelRatio = DefaultModelRatio
}
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
@@ -142,6 +167,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) float64 {
if len(ModelRatio) == 0 {
ModelRatio = DefaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
@@ -179,10 +207,11 @@ func GetCompletionRatio(name string) float64 {
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
return 3
} else if strings.HasPrefix(name, "claude-2") {
return 3
} else if strings.HasPrefix(name, "claude-3") {
return 5
}
return 1
}

44
constant/midjourney.go Normal file
View File

@@ -0,0 +1,44 @@
package constant
var MjNotifyEnabled = false
const (
MjErrorUnknown = 5
MjRequestError = 4
)
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
)
var MidjourneyModel2Action = map[string]string{
"mj_imagine": MjActionImagine,
"mj_describe": MjActionDescribe,
"mj_blend": MjActionBlend,
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
"mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
"mj_zoom": MjActionZoom,
"mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
}

View File

@@ -214,10 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:

View File

@@ -24,6 +24,9 @@ import (
)
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -37,6 +40,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
meta := relaycommon.GenRelayInfo(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
@@ -45,13 +61,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
if testModel == "" {
testModel = adaptor.GetModelList()[0]
meta.UpstreamModelName = testModel
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
adaptor.Init(meta, *request)
request.Model = testModel
meta.UpstreamModelName = testModel
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
@@ -68,11 +85,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
}
if resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil

View File

@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
}
} else {
if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

239
controller/linuxdo.go Normal file
View File

@@ -0,0 +1,239 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -10,9 +10,13 @@ import (
func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
@@ -20,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -38,16 +42,23 @@ func GetAllLogs(c *gin.Context) {
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
if pageSize > 100 {
pageSize = 100
}
userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -10,145 +10,14 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/model"
relay2 "one-api/relay"
"one-api/service"
"strconv"
"strings"
"time"
)
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
err2 := json.Unmarshal(responseBody, &responseStatus)
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
} else {
log.Printf("UpdateMidjourneyTask error3: %v", err)
continue
}
} else {
log.Printf("UpdateMidjourneyTask error4: %v", err)
continue
}
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
@@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []relay2.Midjourney
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
// 如果时间超过一小时且进度不是100%,则认为任务失败
if useTime > 3600000 && task.Progress != "100%" {
responseItem.FailReason = "上游任务超时超过1小时"
responseItem.Status = "FAILURE"
}
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if responseItem.Properties != nil {
propertiesStr, _ := json.Marshal(responseItem.Properties)
task.Properties = string(propertiesStr)
}
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
@@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
if oldTask.Code != 1 {
return true
}

View File

@@ -5,12 +5,29 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func TestStatus(c *gin.Context) {
err := model.PingDB()
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"success": false,
"message": "数据库连接失败",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Server is running",
})
return
}
func GetStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -20,6 +37,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
@@ -43,6 +62,8 @@ func GetStatus(c *gin.Context) {
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled,
"version": common.Version,
},
})
return

View File

@@ -4,12 +4,13 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/moonshot"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -59,8 +60,8 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -100,6 +101,17 @@ func init() {
Parent: nil,
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model

View File

@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{

View File

@@ -12,7 +12,6 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
)
func Relay(c *gin.Context) {
@@ -38,83 +37,58 @@ func Relay(c *gin.Context) {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.StatusCode == http.StatusTooManyRequests {
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
"error": err.Error,
})
}
channelId := c.GetInt("channel_id")
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Message)
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
}
func RelayMidjourney(c *gin.Context) {
relayMode := relayconstant.RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = relayconstant.RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = relayconstant.RelayModeMidjourneyBlend
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
relayMode = relayconstant.RelayModeMidjourneyDescribe
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
relayMode = relayconstant.RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = relayconstant.RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
relayMode = relayconstant.RelayModeMidjourneyChange
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
relayMode = relayconstant.RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
}
relayMode := c.GetInt("relay_mode")
var err *dto.MidjourneyResponse
switch relayMode {
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(429, gin.H{
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
})
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
//if shouldDisableChannel(&err.OpenAIError) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//};''''''''''''''''''''''''''''''''
}
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/model"
"one-api/service"
"strconv"
"sync"
"time"
)
@@ -111,6 +112,33 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
// tradeNo lock
var orderLocks sync.Map
var createLock sync.Mutex
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
@@ -122,6 +150,7 @@ func EpayNotify(c *gin.Context) {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
@@ -135,11 +164,19 @@ func EpayNotify(c *gin.Context) {
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()

View File

@@ -65,6 +65,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username)
session.Set("role", user.Role)
session.Set("status", user.Status)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -683,7 +684,7 @@ func ManageUser(c *gin.Context) {
})
return
}
if err := user.HardDelete(); err != nil {
if err := user.Delete(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -2,18 +2,17 @@ version: '3.4'
services:
new-api:
image: calciumion/new-api:latest
# build: .
image: pengzhile/new-api:latest
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
- "3000:3000"
volumes:
- ./data:/data
- ./data/new-api:/data
- ./logs:/app/logs
environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
@@ -23,13 +22,22 @@ services:
depends_on:
- redis
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
- db
redis:
image: redis:latest
container_name: redis
restart: always
db:
image: mysql:8.2.0
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

View File

@@ -8,6 +8,47 @@ type OpenAIError struct {
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"`
}
type GeneralErrorResponse struct {
Error OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}

View File

@@ -1,7 +1,21 @@
package dto
//type SimpleMjRequest struct {
// Prompt string `json:"prompt"`
// CustomId string `json:"customId"`
// Action string `json:"action"`
// Content string `json:"content"`
//}
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
@@ -9,6 +23,7 @@ type MidjourneyRequest struct {
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
@@ -17,3 +32,64 @@ type MidjourneyResponse struct {
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
}
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
type Properties struct {
FinalPrompt string `json:"finalPrompt"`
FinalZhPrompt string `json:"finalZhPrompt"`
}

View File

@@ -14,6 +14,7 @@ type GeneralOpenAIRequest struct {
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
@@ -82,6 +83,14 @@ func (m Message) StringContent() string {
return string(m.Content)
}
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return true
}
return false
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
@@ -130,9 +139,3 @@ func (m Message) ParseContent() []MediaMessage {
return nil
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -61,3 +61,9 @@ type CompletionsStreamResponse struct {
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

14
go.mod
View File

@@ -4,13 +4,12 @@ module one-api
go 1.18
require (
github.com/chai2010/webp v1.1.1
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.16.0
github.com/go-playground/validator/v10 v10.19.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
@@ -19,7 +18,8 @@ require (
github.com/samber/lo v1.38.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.17.0
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -32,7 +32,7 @@ require (
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -50,7 +50,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
@@ -64,9 +64,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

29
go.sum
View File

@@ -3,8 +3,6 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -19,6 +17,8 @@ github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cn
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
@@ -37,6 +37,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -47,10 +48,10 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
@@ -79,8 +80,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
@@ -108,10 +107,10 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
@@ -176,15 +175,19 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -197,16 +200,14 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -15,6 +15,7 @@ func authHelper(c *gin.Context, minRole int) {
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
if username == nil {
// Check access token
accessToken := c.Request.Header.Get("Authorization")
@@ -33,6 +34,7 @@ func authHelper(c *gin.Context, minRole int) {
role = user.Role
id = user.Id
status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -50,6 +52,14 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort()
return
}
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -100,16 +110,25 @@ func TokenAuth() func(c *gin.Context) {
}
token, err := model.ValidateUserToken(key)
if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId)
@@ -125,17 +144,11 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1])
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}

View File

@@ -4,7 +4,11 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
@@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) {
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
shouldSelectChannel := true
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
// Midjourney
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -87,60 +118,64 @@ func Distribute() func(c *gin.Context) {
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
c.Next()
}

View File

@@ -5,7 +5,7 @@ import (
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
common.LogError(c.Request.Context(), message)
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
c.JSON(statusCode, gin.H{
"description": description,
"type": "new_api_error",
"code": code,
})
c.Abort()
common.LogError(c.Request.Context(), description)
}

View File

@@ -147,7 +147,12 @@ func FixAbility() (int, error) {
return 0, err
}
var channels []Channel
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
}
if err != nil {
return 0, err
}

View File

@@ -204,6 +204,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex

View File

@@ -5,9 +5,11 @@ import (
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"log"
"one-api/common"
"os"
"strings"
"sync"
"time"
)
@@ -60,6 +62,7 @@ func chooseDB() (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -90,6 +93,9 @@ func InitDB() (err error) {
if !common.IsMasterNode {
return nil
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
@@ -148,3 +154,33 @@ func CloseDB() error {
err = sqlDB.Close()
return err
}
var (
lastPingTime time.Time
pingMutex sync.Mutex
)
func PingDB() error {
pingMutex.Lock()
defer pingMutex.Unlock()
if time.Since(lastPingTime) < time.Second*10 {
return nil
}
sqlDB, err := DB.DB()
if err != nil {
log.Printf("Error getting sql.DB from GORM: %v", err)
return err
}
err = sqlDB.Ping()
if err != nil {
log.Printf("Error pinging DB: %v", err)
return err
}
lastPingTime = time.Now()
common.SysLog("Database pinged successfully")
return nil
}

View File

@@ -4,21 +4,23 @@ type Midjourney struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
Action string `json:"action" gorm:"type:varchar(40);index"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
Status string `json:"status"`
Progress string `json:"progress"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段

View File

@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
"one-api/constant"
"strconv"
"strings"
"time"
@@ -30,6 +31,7 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
@@ -65,6 +67,9 @@ func InitOptionMap() {
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = ""
@@ -88,6 +93,7 @@ func InitOptionMap() {
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -155,6 +161,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
@@ -181,6 +189,8 @@ func updateOptionMap(key string, value string) (err error) {
common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
}
}
switch key {
@@ -217,6 +227,12 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer":
common.Footer = value
case "SystemName":

View File

@@ -8,16 +8,17 @@ import (
)
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {

View File

@@ -10,19 +10,20 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {

View File

@@ -21,6 +21,8 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -272,6 +274,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -311,6 +321,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -356,6 +370,18 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil
}
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
return nil

View File

@@ -1,25 +1,27 @@
package ali
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
Prompt string `json:"prompt,omitempty"`
//History []AliMessage `json:"history,omitempty"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Input AliInput `json:"input,omitempty"`
Parameters AliParameters `json:"parameters,omitempty"`
}

View File

@@ -14,41 +14,35 @@ import (
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
//prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.StringContent()
break
}
messages = append(messages, AliMessage{
User: message.StringContent(),
Bot: string(request.Messages[i+1].Content),
})
i++
}
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
@@ -77,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
@@ -242,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
}
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,

View File

@@ -24,21 +24,10 @@ var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
@@ -184,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@@ -220,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",

View File

@@ -9,18 +9,32 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -38,7 +52,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
if a.RequestMode == RequestModeCompletion {
return requestOpenAI2ClaudeComplete(*request), nil
} else {
return requestOpenAI2ClaudeMessage(*request)
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@@ -47,11 +65,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = claudeStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}

View File

@@ -1,7 +1,13 @@
package claude
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
"claude-instant-1.2",
"claude-2",
"claude-2.0",
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
}
var ChannelName = "claude"

View File

@@ -4,14 +4,36 @@ type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
@@ -22,8 +44,25 @@ type ClaudeError struct {
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "end_turn":
return "stop"
case "max_tokens":
return "length"
default:
@@ -24,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
@@ -44,7 +46,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
@@ -52,51 +56,154 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
claudeMessages := make([]ClaudeMessage, 0)
for _, message := range textRequest.Messages {
if message.Role == "system" {
claudeRequest.System = message.StringContent()
} else {
claudeMessage := ClaudeMessage{
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
} else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
return &claudeRequest, nil
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
} else {
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []dto.OpenAITextResponseChoice{choice},
}
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for i, message := range claudeResponse.Content {
content, _ := json.Marshal(message.Text)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
}
}
fullTextResponse.Choices = choices
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage dto.Usage
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
@@ -108,10 +215,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
@@ -128,10 +235,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
@@ -146,12 +274,19 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, responseText
if requestMode == RequestModeCompletion {
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
} else {
if usage.CompletionTokens == 0 {
usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
}
}
return nil, &usage
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -167,7 +302,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
@@ -176,12 +311,17 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens := service.CountTokenText(claudeResponse.Completion, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = promptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = promptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)

View File

@@ -20,6 +20,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := "v1"
if info.ApiVersion != "" {
version = info.ApiVersion
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"

View File

@@ -140,8 +140,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
},
FinishReason: relaycommon.StopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
@@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}
if len(geminiResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",

View File

@@ -0,0 +1,59 @@
package ollama
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/chat", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package ollama
var ModelList []string
var ChannelName = "ollama"

View File

@@ -0,0 +1,18 @@
package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Options *OllamaOptions `json:"options,omitempty"`
}
type OllamaOptions struct {
Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
}

View File

@@ -0,0 +1,32 @@
package ollama
import "one-api/dto"
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
})
}
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Options: &OllamaOptions{
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
},
}
}

View File

@@ -49,11 +49,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
req.Header.Set("api-key", info.ApiKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
if info.ChannelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
req.Header.Set("OpenAI-Organization", info.Organization)
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
//if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API")
//}
return nil
}

View File

@@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body

View File

@@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",

View File

@@ -0,0 +1,63 @@
package perplexity
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if request.TopP >= 1 {
request.TopP = 0.99
}
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package perplexity
var ModelList = []string{
"sonar-small-chat", "sonar-small-online", "sonar-medium-chat", "sonar-medium-online", "mistral-7b-instruct", "mixtral-8x7b-instruct",
}
var ChannelName = "perplexity"

View File

@@ -0,0 +1,21 @@
package perplexity
import "one-api/dto"
func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
}
}

View File

@@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
}
if TencentResponse.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},

View File

@@ -24,8 +24,9 @@ import (
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
for _, message := range request.Messages {
if message.Role == "system" {
if message.Role == "system" && shouldCovertSystemMessage {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.StringContent(),
@@ -126,7 +127,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
}
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@@ -156,7 +157,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
}
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
@@ -235,20 +236,44 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return dataChan, stopChan, nil
}
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
func apiVersion2domain(apiVersion string) string {
switch apiVersion {
case "v1.1":
return "general"
case "v2.1":
return "generalv2"
case "v3.1":
return "generalv3"
case "v3.5":
return "generalv3.5"
}
return "general" + apiVersion
}
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
apiVersion := getAPIVersion(c, modelName)
domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}
func getAPIVersion(c *gin.Context, modelName string) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion != "" {
return apiVersion
}
parts := strings.Split(modelName, "-")
if len(parts) == 2 {
apiVersion = parts[1]
return apiVersion
}
apiVersion = c.GetString("api_version")
if apiVersion != "" {
return apiVersion
}
apiVersion = "v1.1"
common.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}

View File

@@ -36,6 +36,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
if request.TopP >= 1 {
request.TopP = 0.99
}
return requestOpenAI2Zhipu(*request), nil
}

View File

@@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if !zhipuResponse.Success {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",

View File

@@ -34,6 +34,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
if request.TopP >= 1 {
request.TopP = 0.99
}
return requestOpenAI2Zhipu(*request), nil
}

View File

@@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body

View File

@@ -24,6 +24,7 @@ type RelayInfo struct {
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
}
@@ -52,12 +53,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
}
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAzureAPIVersion(c)
info.ApiVersion = GetAPIVersion(c)
}
return info
}

View File

@@ -17,10 +17,10 @@ import (
var StopFinishReason = "stop"
func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: dto.OpenAIError{
Error: dto.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.Open
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
OpenAIErrorWithStatusCode.Error = textResponse.Error
return
}
@@ -66,12 +66,3 @@ func GetAPIVersion(c *gin.Context) string {
}
return apiVersion
}
func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}

View File

@@ -16,6 +16,8 @@ const (
APITypeTencent
APITypeGemini
APITypeZhipu_v4
APITypeOllama
APITypePerplexity
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -43,6 +45,10 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypeGemini
case common.ChannelTypeZhipu_v4:
apiType = APITypeZhipu_v4
case common.ChannelTypeOllama:
apiType = APITypeOllama
case common.ChannelTypePerplexity:
apiType = APITypePerplexity
}
return apiType
}

View File

@@ -17,10 +17,15 @@ const (
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
)
func Path2RelayMode(path string) int {
@@ -48,3 +53,39 @@ func Path2RelayMode(path string) int {
}
return relayMode
}
func Path2RelayModeMidjourney(path string) int {
relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/mj/submit/action") {
// midjourney plus
relayMode = RelayModeMidjourneyAction
} else if strings.HasPrefix(path, "/mj/submit/modal") {
// midjourney plus
relayMode = RelayModeMidjourneyModal
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
// midjourney plus
relayMode = RelayModeMidjourneyShorten
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasPrefix(path, "/mj/submit/describe") {
relayMode = RelayModeMidjourneyDescribe
} else if strings.HasPrefix(path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(path, "/image-seed") {
relayMode = RelayModeMidjourneyTaskImageSeed
} else if strings.HasSuffix(path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}
return relayMode
}

View File

@@ -113,7 +113,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := relaycommon.GetAzureAPIVersion(c)
apiVersion := relaycommon.GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
}

View File

@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
startTime := time.Now()
var imageRequest dto.ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.Model == "" {
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
if userQuota-quota < 0 {
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
@@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
var textResponse dto.ImageResponse
defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
if consumeQuota {
if resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
if resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])

View File

@@ -9,6 +9,7 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
@@ -20,53 +21,6 @@ import (
"github.com/gin-gonic/gin"
)
type Midjourney struct {
MjId string `json:"id"`
Action string `json:"action"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
var DefaultModelPrice = map[string]float64{
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
@@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
var midjRequest Midjourney
var midjRequest dto.MidjourneyDto
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
@@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
return nil
}
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -167,9 +121,164 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
if originTask.Buttons != "" {
var buttons []dto.ActionButton
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
if err == nil {
midjourneyTask.Buttons = buttons
}
}
if originTask.Properties != "" {
var properties dto.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
if err == nil {
midjourneyTask.Properties = &properties
}
}
return
}
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if userQuota-quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
requestURL := c.Request.URL.String()
baseURL := c.GetString("base_url")
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
if err != nil {
return &mjResp.Response
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: constant.MjActionSwapFace,
MjId: midjResponse.Result,
Prompt: "InsightFace",
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: startTime,
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
err = midjourneyTask.Insert()
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
}
c.Writer.WriteHeader(mjResp.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
}
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
if channel.Status != common.ChannelStatusEnabled {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
}
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
requestURL := c.Request.URL.String()
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
if err != nil {
return &midjResponseWithStatus.Response
}
midjResponse := &midjResponseWithStatus.Response
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
respBody, err := json.Marshal(midjResponse)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
return nil
}
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
userId := c.GetInt("id")
var err error
@@ -184,7 +293,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "task_no_found",
}
}
midjourneyTask := getMidjourneyTaskModel(c, originTask)
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &dto.MidjourneyResponse{
@@ -203,16 +312,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
Description: "do_request_failed",
}
}
var tasks []Midjourney
var tasks []dto.MidjourneyDto
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := getMidjourneyTaskModel(c, originTask)
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]Midjourney, 0)
tasks = make([]dto.MidjourneyDto, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
@@ -235,170 +344,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
const (
// type 1 根据 mode 价格不同
MJSubmitActionImagine = "IMAGINE"
MJSubmitActionVariation = "VARIATION" //变换
MJSubmitActionBlend = "BLEND" //混图
MJSubmitActionReroll = "REROLL" //重新生成
// type 2 固定价格
MJSubmitActionDescribe = "DESCRIBE"
MJSubmitActionUpscale = "UPSCALE" // 放大
)
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
//channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
consumeQuota := true
var midjRequest dto.MidjourneyRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
}
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
relayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &dto.MidjourneyResponse{
Code: 4,
Description: "prompt_is_required",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
midjRequest.Action = "IMAGINE"
midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = "DESCRIBE"
midjRequest.Action = constant.MjActionDescribe
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
midjRequest.Action = constant.MjActionBlend
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &dto.MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
return &dto.MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
} else if midjRequest.Index == 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return &dto.MidjourneyResponse{
Code: 4,
Description: "content_is_required",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
params := convertSimpleChangeParams(midjRequest.Content)
params := service.ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "content_parse_failed",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
}
mjId = params.ID
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
//if midjRequest.MaskBase64 == "" {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
//}
mjId = midjRequest.TaskId
midjRequest.Action = constant.MjActionModal
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
} else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &dto.MidjourneyResponse{
Code: 4,
Description: "upscale_task_can_not_be_change",
}
} else if originTask.Status != "SUCCESS" {
return &dto.MidjourneyResponse{
Code: 4,
Description: "task_status_is_not_success",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "channel_not_found",
}
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
}
if channel.Status != common.ChannelStatusEnabled {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
log.Printf("检测到此操作为放大、变换获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
//if channelType == common.ChannelTypeMidjourneyPlus {
// // plus
//} else {
// // 普通版渠道
//
//}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &dto.MidjourneyResponse{
Code: 4,
Description: "unmarshal_model_mapping_failed",
}
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
consumeQuota = false
}
baseURL := common.ChannelBaseURLs[channelType]
//baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
baseURL := c.GetString("base_url")
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(midjRequest)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "marshal_text_request_failed",
}
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
modelPrice := common.GetModelPrice(mjAction, true)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice := common.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
defaultPrice, ok := DefaultModelPrice[mjAction]
defaultPrice, ok := common.DefaultModelPrice[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -423,53 +477,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "create_request_failed",
}
return &midjResponseWithStatus.Response
}
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//mjToken := ""
//if c.Request.Header.Get("ApiKey") != "" {
// mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
//}
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
err = req.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
var midjResponse dto.MidjourneyResponse
midjResponse := &midjResponseWithStatus.Response
defer func(ctx context.Context) {
if consumeQuota {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
@@ -481,7 +496,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -489,41 +504,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}(c.Request.Context())
//if consumeQuota {
//
//}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "read_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
}
}
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
@@ -575,8 +555,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
}
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
}
err = midjourneyTask.Insert()
@@ -593,21 +575,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
_, err = io.Copy(c.Writer, bodyReader)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = resp.Body.Close()
err = bodyReader.Close()
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -622,32 +605,3 @@ type taskChangeParams struct {
Action string
Index int
}
func convertSimpleChangeParams(content string) *taskChangeParams {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &taskChangeParams{}
changeParams.ID = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}

View File

@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -59,7 +60,6 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
}
}
relayInfo.IsStream = textRequest.Stream
relayInfo.UpstreamModelName = textRequest.Model
return textRequest, nil
}
@@ -85,9 +85,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
// set upstream model name
isModelMapped = true
}
}
relayInfo.UpstreamModelName = textRequest.Model
modelPrice := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
@@ -148,10 +150,19 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return service.RelayErrorHandler(resp)
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return openaiErr
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
@@ -169,6 +180,8 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
default:
err = errors.New("unknown relay mode")
promptTokens = 0
@@ -216,6 +229,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return preConsumedQuota, userQuota, nil
}
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
}(c)
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens

View File

@@ -6,8 +6,10 @@ import (
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
@@ -39,6 +41,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &zhipu.Adaptor{}
case constant.APITypeZhipu_v4:
return &zhipu_4v.Adaptor{}
case constant.APITypeOllama:
return &ollama.Adaptor{}
case constant.APITypePerplexity:
return &perplexity.Adaptor{}
}
return nil
}

View File

@@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
apiRouter.GET("/midjourney", controller.GetMidjourney)
@@ -22,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)

View File

@@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
@@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) {
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
}
//relayMjRouter.Use()
}

View File

@@ -1,12 +1,30 @@
package service
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
)
func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
return &dto.MidjourneyResponse{
Code: code,
Description: desc,
}
}
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
return &dto.MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: *MidjourneyErrorWrapper(code, desc),
}
}
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
@@ -23,7 +41,42 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
Code: code,
}
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
Error: openAIError,
StatusCode: statusCode,
}
}
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: dto.OpenAIError{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse dto.GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
errWithStatusCode.Error = errResponse.Error
} else {
errWithStatusCode.Error.Message = errResponse.ToMessage()
}
if errWithStatusCode.Error.Message == "" {
errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}

235
service/midjourney.go Normal file
View File

@@ -0,0 +1,235 @@
package service
import (
"context"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"strconv"
"strings"
"time"
)
func CoverActionToModelName(mjAction string) string {
modelName := "mj_" + strings.ToLower(mjAction)
if mjAction == constant.MjActionSwapFace {
modelName = "swap_face"
}
return modelName
}
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
action := ""
if relayMode == relayconstant.RelayModeMidjourneyAction {
// plus request
err := CoverPlusActionToNormalAction(midjRequest)
if err != nil {
return "", err, false
}
action = midjRequest.Action
} else {
switch relayMode {
case relayconstant.RelayModeMidjourneyImagine:
action = constant.MjActionImagine
case relayconstant.RelayModeMidjourneyDescribe:
action = constant.MjActionDescribe
case relayconstant.RelayModeMidjourneyBlend:
action = constant.MjActionBlend
case relayconstant.RelayModeMidjourneyShorten:
action = constant.MjActionShorten
case relayconstant.RelayModeMidjourneyChange:
action = midjRequest.Action
case relayconstant.RelayModeMidjourneyModal:
action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
}
action = params.Action
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
return "", nil, true
default:
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
}
}
modelName := CoverActionToModelName(action)
return modelName, nil, true
}
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
customId := midjRequest.CustomId
if customId == "" {
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
}
splits := strings.Split(customId, "::")
var action string
if splits[1] == "JOB" {
action = splits[2]
} else {
action = splits[1]
}
if action == "" {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
}
if strings.Contains(action, "upsample") {
index, err := strconv.Atoi(splits[3])
if err != nil {
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = constant.MjActionUpscale
} else if strings.Contains(action, "variation") {
midjRequest.Index = 1
if action == "variation" {
index, err := strconv.Atoi(splits[3])
if err != nil {
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
}
midjRequest.Index = index
midjRequest.Action = constant.MjActionVariation
} else if action == "low_variation" {
midjRequest.Action = constant.MjActionLowVariation
} else if action == "high_variation" {
midjRequest.Action = constant.MjActionHighVariation
}
} else if strings.Contains(action, "pan") {
midjRequest.Action = constant.MjActionPan
midjRequest.Index = 1
} else if strings.Contains(action, "reroll") {
midjRequest.Action = constant.MjActionReRoll
midjRequest.Index = 1
} else if action == "Outpaint" {
midjRequest.Action = constant.MjActionZoom
midjRequest.Index = 1
} else if action == "CustomZoom" {
midjRequest.Action = constant.MjActionCustomZoom
midjRequest.Index = 1
} else if action == "Inpaint" {
midjRequest.Action = constant.MjActionInPaint
midjRequest.Index = 1
} else {
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
}
return nil
}
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &dto.MidjourneyRequest{}
changeParams.TaskId = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
var nullBytes []byte
//var requestBody io.Reader
//requestBody = c.Request.Body
// read request body to json, delete accountFilter and notifyHook
var mapResult map[string]interface{}
// if get request, no need to read request body
if c.Request.Method != "GET" {
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
if !constant.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
}
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
common.SysError("do request failed: " + err.Error())
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
//if statusCode != 200 {
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
//}
err = req.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
err = c.Request.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
//log.Printf("midjResponse: %v", midjResponse)
//for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
//}
return &dto.MidjourneyResponseWithStatusCode{
StatusCode: statusCode,
Response: midjResponse,
}, responseBody, nil
}

View File

@@ -74,7 +74,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
return 0, err

4
web/.gitignore vendored
View File

@@ -21,6 +21,4 @@
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.idea
package-lock.json
yarn.lock
.idea/

View File

@@ -49,8 +49,9 @@
]
},
"devDependencies": {
"prettier": "^2.7.1",
"typescript": "4.4.2"
"prettier": "2.8.8",
"typescript": "4.4.2",
"@babel/plugin-proposal-private-property-in-object": "^7.21.11"
},
"prettier": {
"singleQuote": true,

View File

@@ -8,12 +8,12 @@ import LoginForm from './components/LoginForm';
import NotFound from './pages/NotFound';
import Setting from './pages/Setting';
import EditUser from './pages/User/EditUser';
import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
import { getLogo, getSystemName } from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import LinuxDoOAuth from "./components/LinuxDoOAuth";
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import { StatusContext } from './context/Status';
import Channel from './pages/Channel';
import Token from './pages/Token';
import EditChannel from './pages/Channel/EditChannel';
@@ -21,12 +21,13 @@ import Redemption from './pages/Redemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
import {Layout} from "@douyinfe/semi-ui";
import Midjourney from "./pages/Midjourney";
import Detail from "./pages/Detail";
import { Layout } from '@douyinfe/semi-ui';
import Midjourney from './pages/Midjourney';
import Detail from './pages/Detail';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
function App() {
const [userState, userDispatch] = useContext(UserContext);
// const [statusState, statusDispatch] = useContext(StatusContext);
@@ -47,7 +48,7 @@ function App() {
}
let logo = getLogo();
if (logo) {
let linkElement = document.querySelector("link[rel~='icon']");
let linkElement = document.querySelector('link[rel~=\'icon\']');
if (linkElement) {
linkElement.href = logo;
}
@@ -56,185 +57,193 @@ function App() {
return (
<Layout>
<Layout.Content>
<Routes>
<Route
path='/'
element={
<Suspense fallback={<Loading></Loading>}>
<Home />
</Suspense>
}
/>
<Route
path='/channel'
element={
<PrivateRoute>
<Channel />
</PrivateRoute>
}
/>
<Route
path='/channel/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/channel/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/token'
element={
<PrivateRoute>
<Token />
</PrivateRoute>
}
/>
<Route
path='/redemption'
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path='/user'
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path='/user/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/edit'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
</Suspense>
}
/>
<Route
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Setting />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/topup'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Detail />
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={
<PrivateRoute>
<Midjourney />
</PrivateRoute>
}
/>
<Route
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path='/chat'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
<Route path='*' element={
<NotFound />
} />
</Routes>
</Layout.Content>
<Layout.Content>
<Routes>
<Route
path="/"
element={
<Suspense fallback={<Loading></Loading>}>
<Home />
</Suspense>
}
/>
<Route
path="/channel"
element={
<PrivateRoute>
<Channel />
</PrivateRoute>
}
/>
<Route
path="/channel/edit/:id"
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path="/channel/add"
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path="/token"
element={
<PrivateRoute>
<Token />
</PrivateRoute>
}
/>
<Route
path="/redemption"
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path="/user"
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path="/user/edit/:id"
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path="/user/edit"
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path="/user/reset"
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
</Suspense>
}
/>
<Route
path="/login"
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path="/register"
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path="/reset"
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path="/oauth/github"
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path="/oauth/linuxdo"
element={
<Suspense fallback={<Loading></Loading>}>
<LinuxDoOAuth />
</Suspense>
}
/>
<Route
path="/setting"
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Setting />
</Suspense>
</PrivateRoute>
}
/>
<Route
path="/topup"
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
}
/>
<Route
path="/log"
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path="/detail"
element={
<PrivateRoute>
<Detail />
</PrivateRoute>
}
/>
<Route
path="/midjourney"
element={
<PrivateRoute>
<Midjourney />
</PrivateRoute>
}
/>
<Route
path="/about"
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path="/chat"
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
<Route path="*" element={
<NotFound />
} />
</Routes>
</Layout.Content>
</Layout>
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { getFooterHTML, getSystemName } from '../helpers';
import {Layout} from "@douyinfe/semi-ui";
import { Layout } from '@douyinfe/semi-ui';
const Footer = () => {
const systemName = getSystemName();
@@ -29,30 +29,30 @@ const Footer = () => {
return (
<Layout>
<Layout.Content style={{textAlign: 'center'}}>
<Layout.Content style={{ textAlign: 'center' }}>
{footer ? (
<div
className='custom-footer'
className="custom-footer"
dangerouslySetInnerHTML={{ __html: footer }}
></div>
) : (
<div className='custom-footer'>
<div className="custom-footer">
<a
href='https://github.com/Calcium-Ion/new-api'
target='_blank'
href="https://github.com/Calcium-Ion/new-api"
target="_blank" rel="noreferrer"
>
New API {process.env.REACT_APP_VERSION}{' '}
</a>
{' '}
<a href='https://github.com/Calcium-Ion' target='_blank'>
<a href="https://github.com/Calcium-Ion" target="_blank" rel="noreferrer">
Calcium-Ion
</a>{' '}
开发基于{' '}
<a href='https://github.com/songquanpeng/one-api' target='_blank'>
<a href="https://github.com/songquanpeng/one-api" target="_blank" rel="noreferrer">
One API v0.5.4
</a>{' '}
本项目根据{' '}
<a href='https://opensource.org/licenses/mit-license.php'>
<a href="https://opensource.org/licenses/mit-license.php">
MIT 许可证
</a>{' '}
授权

View File

@@ -14,7 +14,8 @@ const GitHubOAuth = () => {
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
let aff = localStorage.getItem('aff');
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
@@ -41,6 +42,14 @@ const GitHubOAuth = () => {
};
useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
@@ -49,7 +58,7 @@ const GitHubOAuth = () => {
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
<Loader size="large">{prompt}</Loader>
</Dimmer>
</Segment>
);

View File

@@ -1,165 +1,161 @@
import React, {useContext, useEffect, useRef, useState} from 'react';
import {Link, useNavigate} from 'react-router-dom';
import {UserContext} from '../context/User';
import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate } from 'react-router-dom';
import { UserContext } from '../context/User';
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
import { API, getLogo, getSystemName, showSuccess } from '../helpers';
import '../index.css';
import fireworks from 'react-fireworks';
import {
IconKey,
IconUser,
IconHelpCircle
} from '@douyinfe/semi-icons';
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
import {stringToColor} from "../helpers/render";
import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons';
import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
import { stringToColor } from '../helpers/render';
// HeaderBar Buttons
let headerButtons = [
{
text: '关于',
itemKey: 'about',
to: '/about',
icon: <IconHelpCircle/>
},
{
text: '关于',
itemKey: 'about',
to: '/about',
icon: <IconHelpCircle />
}
];
if (localStorage.getItem('chat_link')) {
headerButtons.splice(1, 0, {
name: '聊天',
to: '/chat',
icon: 'comments'
});
headerButtons.splice(1, 0, {
name: '聊天',
to: '/chat',
icon: 'comments'
});
}
const HeaderBar = () => {
const [userState, userDispatch] = useContext(UserContext);
let navigate = useNavigate();
const [userState, userDispatch] = useContext(UserContext);
let navigate = useNavigate();
const [showSidebar, setShowSidebar] = useState(false);
const [dark, setDark] = useState(false);
const systemName = getSystemName();
const logo = getLogo();
var themeMode = localStorage.getItem('theme-mode');
const currentDate = new Date();
// enable fireworks on new year(1.1 and 2.9-2.24)
const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24);
const [showSidebar, setShowSidebar] = useState(false);
const [dark, setDark] = useState(false);
const systemName = getSystemName();
const logo = getLogo();
var themeMode = localStorage.getItem('theme-mode');
const currentDate = new Date();
// enable fireworks on new year(1.1 and 2.9-2.24)
const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24);
async function logout() {
setShowSidebar(false);
await API.get('/api/user/logout');
showSuccess('注销成功!');
userDispatch({type: 'logout'});
localStorage.removeItem('user');
navigate('/login');
async function logout() {
setShowSidebar(false);
await API.get('/api/user/logout');
showSuccess('注销成功!');
userDispatch({ type: 'logout' });
localStorage.removeItem('user');
navigate('/login');
}
const handleNewYearClick = () => {
fireworks.init('root', {});
fireworks.start();
setTimeout(() => {
fireworks.stop();
setTimeout(() => {
window.location.reload();
}, 10000);
}, 3000);
};
useEffect(() => {
if (themeMode === 'dark') {
switchMode(true);
}
if (isNewYear) {
console.log('Happy New Year!');
}
}, []);
const handleNewYearClick = () => {
fireworks.init("root",{});
fireworks.start();
setTimeout(() => {
fireworks.stop();
setTimeout(() => {
window.location.reload();
}, 10000);
}, 3000);
};
const switchMode = (model) => {
const body = document.body;
if (!model) {
body.removeAttribute('theme-mode');
localStorage.setItem('theme-mode', 'light');
} else {
body.setAttribute('theme-mode', 'dark');
localStorage.setItem('theme-mode', 'dark');
}
setDark(model);
};
return (
<>
<Layout>
<div style={{ width: '100%' }}>
<Nav
mode={'horizontal'}
// bodyStyle={{ height: 100 }}
renderWrapper={({ itemElement, isSubNav, isInSubNav, props }) => {
const routerMap = {
about: '/about',
login: '/login',
register: '/register'
};
return (
<Link
style={{ textDecoration: 'none' }}
to={routerMap[props.itemKey]}
>
{itemElement}
</Link>
);
}}
selectedKeys={[]}
// items={headerButtons}
onSelect={key => {
useEffect(() => {
if (themeMode === 'dark') {
switchMode(true);
}
if (isNewYear) {
console.log('Happy New Year!');
}
}, []);
const switchMode = (model) => {
const body = document.body;
if (!model) {
body.removeAttribute('theme-mode');
localStorage.setItem('theme-mode', 'light');
} else {
body.setAttribute('theme-mode', 'dark');
localStorage.setItem('theme-mode', 'dark');
}
setDark(model);
};
return (
<>
<Layout>
<div style={{width: '100%'}}>
<Nav
mode={'horizontal'}
// bodyStyle={{ height: 100 }}
renderWrapper={({itemElement, isSubNav, isInSubNav, props}) => {
const routerMap = {
about: "/about",
login: "/login",
register: "/register",
};
return (
<Link
style={{textDecoration: "none"}}
to={routerMap[props.itemKey]}
>
{itemElement}
</Link>
);
}}
selectedKeys={[]}
// items={headerButtons}
onSelect={key => {
}}
footer={
<>
{isNewYear &&
// happy new year
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
</Dropdown.Menu>
}
>
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
</Dropdown>
}
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
{userState.user ?
<>
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={logout}>退出</Dropdown.Item>
</Dropdown.Menu>
}
>
<Avatar size="small" color={stringToColor(userState.user.username)} style={{ margin: 4 }}>
{userState.user.username[0]}
</Avatar>
<span>{userState.user.username}</span>
</Dropdown>
</>
:
<>
<Nav.Item itemKey={'login'} text={'登录'} icon={<IconKey />} />
<Nav.Item itemKey={'register'} text={'注册'} icon={<IconUser />} />
</>
}
</>
}
}}
footer={
<>
{isNewYear &&
// happy new year
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
</Dropdown.Menu>
}
>
<Nav.Item itemKey={'new-year'} text={'🏮'} />
</Dropdown>
}
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
{userState.user ?
<>
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={logout}>退出</Dropdown.Item>
</Dropdown.Menu>
}
>
</Nav>
</div>
</Layout>
</>
);
<Avatar size="small" color={stringToColor(userState.user.username)} style={{ margin: 4 }}>
{userState.user.username[0]}
</Avatar>
<span>{userState.user.username}</span>
</Dropdown>
</>
:
<>
<Nav.Item itemKey={'login'} text={'登录'} icon={<IconKey />} />
<Nav.Item itemKey={'register'} text={'注册'} icon={<IconUser />} />
</>
}
</>
}
>
</Nav>
</div>
</Layout>
</>
);
};
export default HeaderBar;

View File

@@ -0,0 +1,21 @@
import React from 'react';
import {Icon} from '@douyinfe/semi-ui';
const LinuxDoIcon = (props) => {
function CustomIcon() {
return <svg className='icon' viewBox='0 0 24 24' version='1.1'
xmlns='http://www.w3.org/2000/svg' width='16' height='16' {...props}>
<path
d="M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z"
fill="currentColor"/>
</svg>;
}
return (
<div>
<Icon svg={<CustomIcon/>}/>
</div>
);
};
export default LinuxDoIcon;

View File

@@ -0,0 +1,67 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const LinuxDoOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
let aff = localStorage.getItem('aff');
const res = await API.get(`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default LinuxDoOAuth;

View File

@@ -1,5 +1,5 @@
import React from 'react';
import { Segment, Dimmer, Loader } from 'semantic-ui-react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
const Loading = ({ prompt: name = 'page' }) => {
return (

View File

@@ -1,259 +1,265 @@
import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, isMobile, showError, showInfo, showSuccess, showWarning } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import Turnstile from "react-turnstile";
import { Layout, Card, Image, Form, Button, Divider, Modal } from "@douyinfe/semi-ui";
import Title from "@douyinfe/semi-ui/lib/es/typography/title";
import Text from "@douyinfe/semi-ui/lib/es/typography/text";
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
import Turnstile from 'react-turnstile';
import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login';
import { IconGithubLogo } from '@douyinfe/semi-icons';
import LinuxDoIcon from './LinuxDoIcon';
import WeChatIcon from './WeChatIcon';
const LoginForm = () => {
const [inputs, setInputs] = useState({
username: '',
password: '',
wechat_verification_code: ''
});
const [searchParams, setSearchParams] = useSearchParams();
const [submitted, setSubmitted] = useState(false);
const { username, password } = inputs;
const [userState, userDispatch] = useContext(UserContext);
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
let navigate = useNavigate();
const [status, setStatus] = useState({});
const logo = getLogo();
const [inputs, setInputs] = useState({
username: '',
password: '',
wechat_verification_code: ''
});
const [searchParams, setSearchParams] = useSearchParams();
const [submitted, setSubmitted] = useState(false);
const { username, password } = inputs;
const [userState, userDispatch] = useContext(UserContext);
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
let navigate = useNavigate();
const [status, setStatus] = useState({});
const logo = getLogo();
useEffect(() => {
if (searchParams.get('expired')) {
showError('未登录或登录已过期,请重新登录!');
}
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(status.turnstile_site_key);
}
}
}, []);
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
const onSubmitWeChatVerificationCode = async () => {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
const res = await API.get(
`/api/oauth/wechat?code=${inputs.wechat_verification_code}`
);
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
navigate('/');
showSuccess('登录成功!');
setShowWeChatLoginModal(false);
} else {
showError(message);
}
};
function handleChange(name, value) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
useEffect(() => {
if (searchParams.get('expired')) {
showError('未登录或登录已过期,请重新登录!');
}
async function handleSubmit(e) {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
setSubmitted(true);
if (username && password) {
const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, {
username,
password
});
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true });
}
navigate('/token');
} else {
showError(message);
}
} else {
showError('请输入用户名和密码!');
}
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(status.turnstile_site_key);
}
}
}, []);
// 添加Telegram登录处理函数
const onTelegramLoginClicked = async (response) => {
const fields = ["id", "first_name", "last_name", "username", "photo_url", "auth_date", "hash", "lang"];
const params = {};
fields.forEach((field) => {
if (response[field]) {
params[field] = response[field];
}
});
const res = await API.get(`/api/oauth/telegram/login`, { params });
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
} else {
showError(message);
}
};
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
return (
<div>
<Layout>
<Layout.Header>
</Layout.Header>
<Layout.Content>
<div style={{ justifyContent: 'center', display: "flex", marginTop: 120 }}>
<div style={{ width: 500 }}>
<Card>
<Title heading={2} style={{ textAlign: 'center' }}>
用户登录
</Title>
<Form>
<Form.Input
field={'username'}
label={'用户名'}
placeholder='用户名'
name='username'
onChange={(value) => handleChange('username', value)}
/>
<Form.Input
field={'password'}
label={'密码'}
placeholder='密码'
name='password'
type='password'
onChange={(value) => handleChange('password', value)}
/>
const onWeChatLoginClicked = () => {
setShowWeChatLoginModal(true);
};
<Button theme='solid' style={{ width: '100%' }} type={'primary'} size='large'
htmlType={'submit'} onClick={handleSubmit}>
登录
</Button>
</Form>
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: 20 }}>
<Text>
没有账号请先 <Link to='/register'>注册账号</Link>
</Text>
<Text>
忘记密码 <Link to='/reset'>点击重置</Link>
</Text>
</div>
{status.github_oauth || status.wechat_login || status.telegram_oauth ? (
<>
<Divider margin='12px' align='center'>
第三方登录
</Divider>
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
{status.github_oauth ? (
<Button
type='primary'
icon={<IconGithubLogo />}
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
/>
) : (
<></>
)}
{/*{status.wechat_login ? (*/}
{/* <Button*/}
{/* circular*/}
{/* color='green'*/}
{/* icon='wechat'*/}
{/* onClick={onWeChatLoginClicked}*/}
{/* />*/}
{/*) : (*/}
{/* <></>*/}
{/*)}*/}
{status.telegram_oauth ? (
<TelegramLoginButton dataOnauth={onTelegramLoginClicked} botName={status.telegram_bot_name} />
) : (
<></>
)}
</div>
</>
) : (
<></>
)}
{/*<Modal*/}
{/* onClose={() => setShowWeChatLoginModal(false)}*/}
{/* onOpen={() => setShowWeChatLoginModal(true)}*/}
{/* open={showWeChatLoginModal}*/}
{/* size={'mini'}*/}
{/*>*/}
{/* <Modal.Content>*/}
{/* <Modal.Description>*/}
{/* <Image src={status.wechat_qrcode} fluid/>*/}
{/* <div style={{textAlign: 'center'}}>*/}
{/* <p>*/}
{/* 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)*/}
{/* </p>*/}
{/* </div>*/}
{/* <Form size='large'>*/}
{/* <Form.Input*/}
{/* field={'wechat_verification_code'}*/}
{/* placeholder='验证码'*/}
{/* name='wechat_verification_code'*/}
{/* value={inputs.wechat_verification_code}*/}
{/* onChange={handleChange}*/}
{/* />*/}
{/* <Button*/}
{/* color=''*/}
{/* fluid*/}
{/* size='large'*/}
{/* onClick={onSubmitWeChatVerificationCode}*/}
{/* >*/}
{/* 登录*/}
{/* </Button>*/}
{/* </Form>*/}
{/* </Modal.Description>*/}
{/* </Modal.Content>*/}
{/*</Modal>*/}
</Card>
{turnstileEnabled ? (
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
</div>
) : (
<></>
)}
</div>
</div>
</Layout.Content>
</Layout>
</div>
const onSubmitWeChatVerificationCode = async () => {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
const res = await API.get(
`/api/oauth/wechat?code=${inputs.wechat_verification_code}`
);
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
navigate('/');
showSuccess('登录成功!');
setShowWeChatLoginModal(false);
} else {
showError(message);
}
};
function handleChange(name, value) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
}
async function handleSubmit(e) {
if (turnstileEnabled && turnstileToken === '') {
showInfo('请稍后几秒重试Turnstile 正在检查用户环境!');
return;
}
setSubmitted(true);
if (username && password) {
const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, {
username,
password
});
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true });
}
navigate('/token');
} else {
showError(message);
}
} else {
showError('请输入用户名和密码!');
}
}
// 添加Telegram登录处理函数
const onTelegramLoginClicked = async (response) => {
const fields = ['id', 'first_name', 'last_name', 'username', 'photo_url', 'auth_date', 'hash', 'lang'];
const params = {};
fields.forEach((field) => {
if (response[field]) {
params[field] = response[field];
}
});
const res = await API.get(`/api/oauth/telegram/login`, { params });
const { success, message, data } = res.data;
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
} else {
showError(message);
}
};
return (
<div>
<Layout>
<Layout.Header>
</Layout.Header>
<Layout.Content>
<div style={{ justifyContent: 'center', display: 'flex', marginTop: 120 }}>
<div style={{ width: 500 }}>
<Card>
<Title heading={2} style={{ textAlign: 'center' }}>
用户登录
</Title>
<Form>
<Form.Input
field={'username'}
label={'用户名'}
placeholder="用户名"
name="username"
onChange={(value) => handleChange('username', value)}
/>
<Form.Input
field={'password'}
label={'密码'}
placeholder="密码"
name="password"
type="password"
onChange={(value) => handleChange('password', value)}
/>
<Button theme="solid" style={{ width: '100%' }} type={'primary'} size="large"
htmlType={'submit'} onClick={handleSubmit}>
登录
</Button>
</Form>
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: 20 }}>
<Text>
没有账号请先 <Link to="/register">注册账号</Link>
</Text>
<Text>
忘记密码 <Link to="/reset">点击重置</Link>
</Text>
</div>
{status.github_oauth || status.linuxdo_oauth || status.wechat_login || status.telegram_oauth ? (
<>
<Divider margin="12px" align="center">
第三方登录
</Divider>
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
{status.github_oauth ? (
<Button
type="primary"
icon={<IconGithubLogo />}
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
/>
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
type="primary"
icon={<LinuxDoIcon />}
style={{color: '#000'}}
onClick={() => onLinuxDoOAuthClicked(status.linuxdo_client_id)}
/>
) : (
<></>
)}
{status.wechat_login ? (
<Button
type="primary"
style={{ color: 'rgba(var(--semi-green-5), 1)' }}
icon={<Icon svg={<WeChatIcon />} />}
onClick={onWeChatLoginClicked}
/>
) : (
<></>
)}
{status.telegram_oauth ? (
<TelegramLoginButton dataOnauth={onTelegramLoginClicked} botName={status.telegram_bot_name} />
) : (
<></>
)}
</div>
</>
) : (
<></>
)}
<Modal
title="微信扫码登录"
visible={showWeChatLoginModal}
maskClosable={true}
onOk={onSubmitWeChatVerificationCode}
onCancel={() => setShowWeChatLoginModal(false)}
okText={'登录'}
size={'small'}
centered={true}
>
<div style={{ display: 'flex', alignItem: 'center', flexDirection: 'column' }}>
<img src={status.wechat_qrcode} />
</div>
<div style={{ textAlign: 'center' }}>
<p>
微信扫码关注公众号输入验证码获取验证码三分钟内有效
</p>
</div>
<Form size="large">
<Form.Input
field={'wechat_verification_code'}
placeholder="验证码"
label={'验证码'}
value={inputs.wechat_verification_code}
onChange={(value) => handleChange('wechat_verification_code', value)}
/>
</Form>
</Modal>
</Card>
{turnstileEnabled ? (
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
</div>
) : (
<></>
)}
</div>
</div>
</Layout.Content>
</Layout>
</div>
);
};
export default LoginForm;

View File

@@ -1,501 +1,399 @@
import React, {useEffect, useState} from 'react';
import {Label} from 'semantic-ui-react';
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
import React, { useEffect, useState } from 'react';
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal, Spin, Space} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
import {
IconAt,
IconHistogram,
IconGift,
IconKey,
IconUser,
IconLayers,
IconSetting,
IconCreditCard,
IconSemiLogo,
IconHome,
IconMore
} from '@douyinfe/semi-icons';
import Paragraph from "@douyinfe/semi-ui/lib/es/typography/paragraph";
import { Avatar, Button, Form, Layout, Modal, Select, Space, Spin, Table, Tag } from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import { renderNumber, renderQuota, stringToColor } from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
const { Header } = Layout;
const {Header} = Layout;
function renderTimestamp(timestamp) {
return (
<>
{timestamp2string(timestamp)}
</>
);
return (<>
{timestamp2string(timestamp)}
</>);
}
const MODE_OPTIONS = [
{key: 'all', text: '全部用户', value: 'all'},
{key: 'self', text: '当前用户', value: 'self'}
];
const MODE_OPTIONS = [{ key: 'all', text: '全部用户', value: 'all' }, { key: 'self', text: '当前用户', value: 'self' }];
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
'light-blue', 'lime', 'orange', 'pink',
'purple', 'red', 'teal', 'violet', 'yellow'
]
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', 'light-blue', 'lime', 'orange', 'pink', 'purple', 'red', 'teal', 'violet', 'yellow'];
function renderType(type) {
switch (type) {
case 1:
return <Tag color='cyan' size='large'> 充值 </Tag>;
case 2:
return <Tag color='lime' size='large'> 消费 </Tag>;
case 3:
return <Tag color='orange' size='large'> 管理 </Tag>;
case 4:
return <Tag color='purple' size='large'> 系统 </Tag>;
default:
return <Tag color='black' size='large'> 未知 </Tag>;
}
switch (type) {
case 1:
return <Tag color="cyan" size="large"> 充值 </Tag>;
case 2:
return <Tag color="lime" size="large"> 消费 </Tag>;
case 3:
return <Tag color="orange" size="large"> 管理 </Tag>;
case 4:
return <Tag color="purple" size="large"> 系统 </Tag>;
default:
return <Tag color="black" size="large"> 未知 </Tag>;
}
}
function renderIsStream(bool) {
if (bool) {
return <Tag color='blue' size='large'></Tag>;
} else {
return <Tag color='purple' size='large'>非流</Tag>;
}
if (bool) {
return <Tag color="blue" size="large"></Tag>;
} else {
return <Tag color="purple" size="large">非流</Tag>;
}
}
function renderUseTime(type) {
const time = parseInt(type);
if (time < 101) {
return <Tag color='green' size='large'> {time} s </Tag>;
} else if (time < 300) {
return <Tag color='orange' size='large'> {time} s </Tag>;
} else {
return <Tag color='red' size='large'> {time} s </Tag>;
}
const time = parseInt(type);
if (time < 101) {
return <Tag color="green" size="large"> {time} s </Tag>;
} else if (time < 300) {
return <Tag color="orange" size="large"> {time} s </Tag>;
} else {
return <Tag color="red" size="large"> {time} s </Tag>;
}
}
const LogsTable = () => {
const columns = [
{
title: '时间',
dataIndex: 'timestamp2string',
},
{
title: '渠道',
dataIndex: 'channel',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
isAdminUser ?
record.type === 0 || record.type === 2 ?
<div>
{<Tag color={colors[parseInt(text) % colors.length]} size='large'> {text} </Tag>}
</div>
:
<></>
:
<></>
);
},
},
{
title: '用户',
dataIndex: 'username',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
isAdminUser ?
<div>
<Avatar size="small" color={stringToColor(text)} style={{marginRight: 4}}
onClick={() => showUserInfo(record.user_id)}>
{typeof text === 'string' && text.slice(0, 1)}
</Avatar>
{text}
</div>
:
<></>
);
},
},
{
title: '令牌',
dataIndex: 'token_name',
render: (text, record, index) => {
return (
record.type === 0 || record.type === 2 ?
<div>
<Tag color='grey' size='large' onClick={() => {
copyText(text)
}}> {text} </Tag>
</div>
:
<></>
);
},
},
{
title: '类型',
dataIndex: 'type',
render: (text, record, index) => {
return (
<div>
{renderType(text)}
</div>
);
},
},
{
title: '模型',
dataIndex: 'model_name',
render: (text, record, index) => {
return (
record.type === 0 || record.type === 2 ?
<div>
<Tag color={stringToColor(text)} size='large' onClick={() => {
copyText(text)
}}> {text} </Tag>
</div>
:
<></>
);
},
},
{
title: '用时',
dataIndex: 'use_time',
render: (text, record, index) => {
return (
<div>
<Space>
{renderUseTime(text)}
{renderIsStream(record.is_stream)}
</Space>
</div>
);
},
},
{
title: '提示',
dataIndex: 'prompt_tokens',
render: (text, record, index) => {
return (
record.type === 0 || record.type === 2 ?
<div>
{<span> {text} </span>}
</div>
:
<></>
);
},
},
{
title: '补全',
dataIndex: 'completion_tokens',
render: (text, record, index) => {
return (
parseInt(text) > 0 && (record.type === 0 || record.type === 2) ?
<div>
{<span> {text} </span>}
</div>
:
<></>
);
},
},
{
title: '花费',
dataIndex: 'quota',
render: (text, record, index) => {
return (
record.type === 0 || record.type === 2 ?
<div>
{
renderQuota(text, 6)
}
</div>
:
<></>
);
}
},
{
title: '详情',
dataIndex: 'content',
render: (text, record, index) => {
return <Paragraph ellipsis={{ rows: 2, showTooltip: { type: 'popover', opts: { style: { width: 240 } } } }} style={{ maxWidth: 240}}>
{text}
</Paragraph>
}
}
];
const [logs, setLogs] = useState([]);
const [showStat, setShowStat] = useState(false);
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
// 初始化start_timestamp为前一天
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
const {username, token_name, model_name, start_timestamp, end_timestamp, channel} = inputs;
const [stat, setStat] = useState({
quota: 0,
token: 0
});
const handleInputChange = (value, name) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
const {success, message, data} = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const {success, message, data} = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
const handleEyeClick = async () => {
setLoadingStat(true);
if (isAdminUser) {
await getLogStat();
} else {
await getLogSelfStat();
}
setShowStat(true);
setLoadingStat(false);
};
const showUserInfo = async (userId) => {
if (!isAdminUser) {
return;
}
const res = await API.get(`/api/user/${userId}`);
const {success, message, data} = res.data;
if (success) {
Modal.info({
title: '用户信息',
content: <div style={{padding: 12}}>
<p>用户名: {data.username}</p>
<p>余额: {renderQuota(data.quota)}</p>
<p>已用额度{renderQuota(data.used_quota)}</p>
<p>请求次数{renderNumber(data.request_count)}</p>
</div>,
centered: true,
})
} else {
showError(message);
}
};
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount);
const columns = [{
title: '时间', dataIndex: 'timestamp2string'
}, {
title: '渠道',
dataIndex: 'channel',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (isAdminUser ? record.type === 0 || record.type === 2 ? <div>
{<Tag color={colors[parseInt(text) % colors.length]} size="large"> {text} </Tag>}
</div> : <></> : <></>);
}
const loadLogs = async (startIdx) => {
setLoading(true);
let url = '';
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else {
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const {success, message, data} = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
setLogsFormat(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
const handlePageChange = page => {
setActivePage(page);
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
loadLogs(page - 1).then(r => {
});
}
};
const refresh = async () => {
// setLoading(true);
setActivePage(1);
await loadLogs(0);
};
const copyText = async (text) => {
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
}
}, {
title: '用户',
dataIndex: 'username',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (isAdminUser ? <div>
<Avatar size="small" color={stringToColor(text)} style={{ marginRight: 4 }}
onClick={() => showUserInfo(record.user_id)}>
{typeof text === 'string' && text.slice(0, 1)}
</Avatar>
{text}
</div> : <></>);
}
}, {
title: '令牌', dataIndex: 'token_name', render: (text, record, index) => {
return (record.type === 0 || record.type === 2 ? <div>
<Tag color="grey" size="large" onClick={() => {
copyText(text);
}}> {text} </Tag>
</div> : <></>);
}
}, {
title: '类型', dataIndex: 'type', render: (text, record, index) => {
return (<div>
{renderType(text)}
</div>);
}
}, {
title: '模型', dataIndex: 'model_name', render: (text, record, index) => {
return (record.type === 0 || record.type === 2 ? <div>
<Tag color={stringToColor(text)} size="large" onClick={() => {
copyText(text);
}}> {text} </Tag>
</div> : <></>);
}
}, {
title: '用时', dataIndex: 'use_time', render: (text, record, index) => {
return (<div>
<Space>
{renderUseTime(text)}
{renderIsStream(record.is_stream)}
</Space>
</div>);
}
}, {
title: '提示', dataIndex: 'prompt_tokens', render: (text, record, index) => {
return (record.type === 0 || record.type === 2 ? <div>
{<span> {text} </span>}
</div> : <></>);
}
}, {
title: '补全', dataIndex: 'completion_tokens', render: (text, record, index) => {
return (parseInt(text) > 0 && (record.type === 0 || record.type === 2) ? <div>
{<span> {text} </span>}
</div> : <></>);
}
}, {
title: '花费', dataIndex: 'quota', render: (text, record, index) => {
return (record.type === 0 || record.type === 2 ? <div>
{renderQuota(text, 6)}
</div> : <></>);
}
}, {
title: '详情', dataIndex: 'content', render: (text, record, index) => {
return <Paragraph ellipsis={{ rows: 2, showTooltip: { type: 'popover', opts: { style: { width: 240 } } } }}
style={{ maxWidth: 240 }}>
{text}
</Paragraph>;
}
}];
useEffect(() => {
refresh().then();
}, [logType]);
const [logs, setLogs] = useState([]);
const [showStat, setShowStat] = useState(false);
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
// 初始化start_timestamp为前一天
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
const searchLogs = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadLogs(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
const {success, message, data} = res.data;
if (success) {
setLogs(data);
setActivePage(1);
} else {
showError(message);
}
setSearching(false);
};
const [stat, setStat] = useState({
quota: 0, token: 0
});
const handleKeywordChange = async (e, {value}) => {
setSearchKeyword(value.trim());
};
const handleInputChange = (value, name) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const sortLog = (key) => {
if (logs.length === 0) return;
setLoading(true);
let sortedLogs = [...logs];
if (typeof sortedLogs[0][key] === 'string') {
sortedLogs.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]);
});
} else {
sortedLogs.sort((a, b) => {
if (a[key] === b[key]) return 0;
if (a[key] > b[key]) return -1;
if (a[key] < b[key]) return 1;
});
}
if (sortedLogs[0].id === logs[0].id) {
sortedLogs.reverse();
}
setLogs(sortedLogs);
setLoading(false);
};
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
return (
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
} else {
showError(message);
}
};
const handleEyeClick = async () => {
setLoadingStat(true);
if (isAdminUser) {
await getLogStat();
} else {
await getLogSelfStat();
}
setShowStat(true);
setLoadingStat(false);
};
const showUserInfo = async (userId) => {
if (!isAdminUser) {
return;
}
const res = await API.get(`/api/user/${userId}`);
const { success, message, data } = res.data;
if (success) {
Modal.info({
title: '用户信息', content: <div style={{ padding: 12 }}>
<p>用户名: {data.username}</p>
<p>余额: {renderQuota(data.quota)}</p>
<p>已用额度{renderQuota(data.used_quota)}</p>
<p>请求次数{renderNumber(data.request_count)}</p>
</div>, centered: true
});
} else {
showError(message);
}
};
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount);
};
const loadLogs = async (startIdx, pageSize, logType = 0) => {
setLoading(true);
let url = '';
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * pageSize, data.length, ...data);
setLogsFormat(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
const pageData = logs.slice((activePage - 1) * pageSize, activePage * pageSize);
const handlePageChange = page => {
setActivePage(page);
if (page === Math.ceil(logs.length / pageSize) + 1) {
// In this case we have to load more data and then append them.
loadLogs(page - 1, pageSize, logType).then(r => {
});
}
};
const handlePageSizeChange = async (size) => {
localStorage.setItem('page-size', size + '');
setPageSize(size);
setActivePage(1);
loadLogs(0, size)
.then()
.catch((reason) => {
showError(reason);
});
};
const refresh = async (localLogType) => {
// setLoading(true);
setActivePage(1);
await loadLogs(0, pageSize, localLogType);
};
const copyText = async (text) => {
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
}
};
useEffect(() => {
// console.log('default effect')
const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
setPageSize(localPageSize);
loadLogs(0, localPageSize)
.then()
.catch((reason) => {
showError(reason);
});
}, []);
const searchLogs = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadLogs(0, pageSize);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
const { success, message, data } = res.data;
if (success) {
setLogs(data);
setActivePage(1);
} else {
showError(message);
}
setSearching(false);
};
return (<>
<Layout>
<Header>
<Spin spinning={loadingStat}>
<h3>使用明细总消耗额度
<span onClick={handleEyeClick} style={{
cursor: 'pointer', color: 'gray'
}}>{showStat ? renderQuota(stat.quota) : '点击查看'}</span>
</h3>
</Spin>
</Header>
<Form layout="horizontal" style={{ marginTop: 10 }}>
<>
<Layout>
<Header>
<Spin spinning={loadingStat}>
<h3>使用明细总消耗额度
<span onClick={handleEyeClick} style={{cursor: 'pointer', color: 'gray'}}>{showStat?renderQuota(stat.quota):"点击查看"}</span>
</h3>
</Spin>
</Header>
<Form layout='horizontal' style={{marginTop: 10}}>
<>
<Form.Input field="token_name" label='令牌名称' style={{width: 176}} value={token_name}
placeholder={'可选值'} name='token_name'
onChange={value => handleInputChange(value, 'token_name')}/>
<Form.Input field="model_name" label='模型名称' style={{width: 176}} value={model_name}
placeholder='可选值'
name='model_name'
onChange={value => handleInputChange(value, 'model_name')}/>
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
initValue={start_timestamp}
value={start_timestamp} type='dateTime'
name='start_timestamp'
onChange={value => handleInputChange(value, 'start_timestamp')}/>
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
initValue={end_timestamp}
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
{
isAdminUser && <>
<Form.Input field="channel" label='渠道 ID' style={{width: 176}} value={channel}
placeholder='可选值' name='channel'
onChange={value => handleInputChange(value, 'channel')}/>
<Form.Input field="username" label='用户名称' style={{width: 176}} value={username}
placeholder={'可选值'} name='username'
onChange={value => handleInputChange(value, 'username')}/>
</>
}
<Form.Section>
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh} loading={loading}>查询</Button>
</Form.Section>
</>
</Form>
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
total: logCount,
pageSizeOpts: [10, 20, 50, 100],
onPageChange: handlePageChange,
}}/>
<Select defaultValue="0" style={{width: 120}} onChange={
(value) => {
setLogType(parseInt(value));
}
}>
<Select.Option value="0">全部</Select.Option>
<Select.Option value="1">充值</Select.Option>
<Select.Option value="2">消费</Select.Option>
<Select.Option value="3">管理</Select.Option>
<Select.Option value="4">系统</Select.Option>
</Select>
</Layout>
<Form.Input field="token_name" label="令牌名称" style={{ width: 176 }} value={token_name}
placeholder={'可选值'} name="token_name"
onChange={value => handleInputChange(value, 'token_name')} />
<Form.Input field="model_name" label="模型名称" style={{ width: 176 }} value={model_name}
placeholder="可选值"
name="model_name"
onChange={value => handleInputChange(value, 'model_name')} />
<Form.DatePicker field="start_timestamp" label="起始时间" style={{ width: 272 }}
initValue={start_timestamp}
value={start_timestamp} type="dateTime"
name="start_timestamp"
onChange={value => handleInputChange(value, 'start_timestamp')} />
<Form.DatePicker field="end_timestamp" fluid label="结束时间" style={{ width: 272 }}
initValue={end_timestamp}
value={end_timestamp} type="dateTime"
name="end_timestamp"
onChange={value => handleInputChange(value, 'end_timestamp')} />
{isAdminUser && <>
<Form.Input field="channel" label="渠道 ID" style={{ width: 176 }} value={channel}
placeholder="可选值" name="channel"
onChange={value => handleInputChange(value, 'channel')} />
<Form.Input field="username" label="用户名称" style={{ width: 176 }} value={username}
placeholder={'可选值'} name="username"
onChange={value => handleInputChange(value, 'username')} />
</>}
<Form.Section>
<Button label="查询" type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh} loading={loading}>查询</Button>
</Form.Section>
</>
);
</Form>
<Table style={{ marginTop: 5 }} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: pageSize,
total: logCount,
pageSizeOpts: [10, 20, 50, 100],
showSizeChanger: true,
onPageSizeChange: (size) => {
handlePageSizeChange(size).then();
},
onPageChange: handlePageChange
}} />
<Select defaultValue="0" style={{ width: 120 }} onChange={(value) => {
setLogType(parseInt(value));
refresh(parseInt(value)).then();
}}>
<Select.Option value="0">全部</Select.Option>
<Select.Option value="1">充值</Select.Option>
<Select.Option value="2">消费</Select.Option>
<Select.Option value="3">管理</Select.Option>
<Select.Option value="4">系统</Select.Option>
</Select>
</Layout>
</>);
};
export default LogsTable;

View File

@@ -1,430 +1,454 @@
import React, {useEffect, useState} from 'react';
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
import React, { useEffect, useState } from 'react';
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
import {
Table,
Avatar,
Tag,
Form,
Button,
Layout,
Select,
Popover,
Modal,
ImagePreview,
Typography, Progress
} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
import { Banner, Button, Form, ImagePreview, Layout, Modal, Progress, Table, Tag, Typography } from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
'light-blue', 'lime', 'orange', 'pink',
'purple', 'red', 'teal', 'violet', 'yellow'
]
'light-blue', 'lime', 'orange', 'pink',
'purple', 'red', 'teal', 'violet', 'yellow'
];
function renderType(type) {
switch (type) {
case 'IMAGINE':
return <Tag color="blue" size='large'>绘图</Tag>;
case 'UPSCALE':
return <Tag color="orange" size='large'>放大</Tag>;
case 'VARIATION':
return <Tag color="purple" size='large'>变换</Tag>;
case 'DESCRIBE':
return <Tag color="yellow" size='large'>图生文</Tag>;
case 'BLEAND':
return <Tag color="lime" size='large'>图混合</Tag>;
default:
return <Tag color="white" size='large'>未知</Tag>;
}
switch (type) {
case 'IMAGINE':
return <Tag color="blue" size="large">绘图</Tag>;
case 'UPSCALE':
return <Tag color="orange" size="large">放大</Tag>;
case 'VARIATION':
return <Tag color="purple" size="large">变换</Tag>;
case 'HIGH_VARIATION':
return <Tag color="purple" size="large">强变换</Tag>;
case 'LOW_VARIATION':
return <Tag color="purple" size="large">弱变换</Tag>;
case 'PAN':
return <Tag color="cyan" size="large">平移</Tag>;
case 'DESCRIBE':
return <Tag color="yellow" size="large">图生文</Tag>;
case 'BLEND':
return <Tag color="lime" size="large">图混合</Tag>;
case 'SHORTEN':
return <Tag color="pink" size="large">缩词</Tag>;
case 'REROLL':
return <Tag color="indigo" size="large">重绘</Tag>;
case 'INPAINT':
return <Tag color="violet" size="large">局部重绘-提交</Tag>;
case 'ZOOM':
return <Tag color="teal" size="large">变焦</Tag>;
case 'CUSTOM_ZOOM':
return <Tag color="teal" size="large">自定义变焦-提交</Tag>;
case 'MODAL':
return <Tag color="green" size="large">窗口处理</Tag>;
case 'SWAP_FACE':
return <Tag color="light-green" size="large">换脸</Tag>;
default:
return <Tag color="white" size="large">未知</Tag>;
}
}
function renderCode(code) {
switch (code) {
case 1:
return <Tag color="green" size='large'>已提交</Tag>;
case 21:
return <Tag color="lime" size='large'>排队</Tag>;
case 22:
return <Tag color="orange" size='large'>重复提交</Tag>;
default:
return <Tag color="white" size='large'></Tag>;
}
switch (code) {
case 1:
return <Tag color="green" size="large">已提交</Tag>;
case 21:
return <Tag color="lime" size="large">等待</Tag>;
case 22:
return <Tag color="orange" size="large">重复提交</Tag>;
case 0:
return <Tag color="yellow" size="large">提交</Tag>;
default:
return <Tag color="white" size="large">未知</Tag>;
}
}
function renderStatus(type) {
// Ensure all cases are string literals by adding quotes.
switch (type) {
case 'SUCCESS':
return <Tag color="green" size='large'>成功</Tag>;
case 'NOT_START':
return <Tag color="grey" size='large'>未启动</Tag>;
case 'SUBMITTED':
return <Tag color="yellow" size='large'>队列中</Tag>;
case 'IN_PROGRESS':
return <Tag color="blue" size='large'>执行中</Tag>;
case 'FAILURE':
return <Tag color="red" size='large'>失败</Tag>;
default:
return <Tag color="white" size='large'>未知</Tag>;
}
// Ensure all cases are string literals by adding quotes.
switch (type) {
case 'SUCCESS':
return <Tag color="green" size="large">成功</Tag>;
case 'NOT_START':
return <Tag color="grey" size="large">未启动</Tag>;
case 'SUBMITTED':
return <Tag color="yellow" size="large">队列中</Tag>;
case 'IN_PROGRESS':
return <Tag color="blue" size="large">执行中</Tag>;
case 'FAILURE':
return <Tag color="red" size="large">失败</Tag>;
case 'MODAL':
return <Tag color="yellow" size="large">窗口等待</Tag>;
default:
return <Tag color="white" size="large">未知</Tag>;
}
}
const renderTimestamp = (timestampInSeconds) => {
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
const year = date.getFullYear(); // 获取年份
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份从0开始需要+1并保证两位数
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
const year = date.getFullYear(); // 获取年份
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份从0开始需要+1并保证两位数
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
};
const LogsTable = () => {
const [isModalOpen, setIsModalOpen] = useState(false);
const [modalContent, setModalContent] = useState('');
const columns = [
{
title: '提交时间',
dataIndex: 'submit_time',
render: (text, record, index) => {
return (
<div>
{renderTimestamp(text / 1000)}
</div>
);
},
},
{
title: '渠道',
dataIndex: 'channel_id',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
const [isModalOpen, setIsModalOpen] = useState(false);
const [modalContent, setModalContent] = useState('');
const columns = [
{
title: '提交时间',
dataIndex: 'submit_time',
render: (text, record, index) => {
return (
<div>
{renderTimestamp(text / 1000)}
</div>
);
}
},
{
title: '渠道',
dataIndex: 'channel_id',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
<Tag color={colors[parseInt(text) % colors.length]} size='large' onClick={() => {
copyText(text); // 假设copyText是用于文本复制的函数
}}> {text} </Tag>
</div>
<div>
<Tag color={colors[parseInt(text) % colors.length]} size="large" onClick={() => {
copyText(text); // 假设copyText是用于文本复制的函数
}}> {text} </Tag>
</div>
);
},
},
{
title: '类型',
dataIndex: 'action',
render: (text, record, index) => {
return (
<div>
{renderType(text)}
</div>
);
},
},
{
title: '任务ID',
dataIndex: 'mj_id',
render: (text, record, index) => {
return (
<div>
{text}
</div>
);
},
},
{
title: '提交结果',
dataIndex: 'code',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderCode(text)}
</div>
);
},
},
{
title: '任务状态',
dataIndex: 'status',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderStatus(text)}
</div>
);
},
},
{
title: '进度',
dataIndex: 'progress',
render: (text, record, index) => {
return (
<div>
{
// 转换例如100%为数字100如果text未定义返回0
<Progress stroke={record.status === "FAILURE"?"var(--semi-color-warning)":null} percent={text ? parseInt(text.replace('%', '')) : 0} showInfo={true}
aria-label="drawing progress"/>
}
</div>
);
},
},
{
title: '结果图片',
dataIndex: 'image_url',
render: (text, record, index) => {
if (!text) {
return '无';
}
return (
<Button
onClick={() => {
setModalImageUrl(text); // 更新图片URL状态
setIsModalOpenurl(true); // 打开模态框
}}
>
查看图片
</Button>
);
}
},
{
title: 'Prompt',
dataIndex: 'prompt',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: 'PromptEn',
dataIndex: 'prompt_en',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: '失败原因',
dataIndex: 'fail_reason',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
);
}
},
{
title: '类型',
dataIndex: 'action',
render: (text, record, index) => {
return (
<div>
{renderType(text)}
</div>
);
}
},
{
title: '任务ID',
dataIndex: 'mj_id',
render: (text, record, index) => {
return (
<div>
{text}
</div>
);
}
},
{
title: '提交结果',
dataIndex: 'code',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderCode(text)}
</div>
);
}
},
{
title: '任务状态',
dataIndex: 'status',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderStatus(text)}
</div>
);
}
},
{
title: '进度',
dataIndex: 'progress',
render: (text, record, index) => {
return (
<div>
{
// 转换例如100%为数字100如果text未定义返回0
<Progress stroke={record.status === 'FAILURE' ? 'var(--semi-color-warning)' : null}
percent={text ? parseInt(text.replace('%', '')) : 0} showInfo={true}
aria-label="drawing progress" />
}
</div>
);
}
},
{
title: '结果图片',
dataIndex: 'image_url',
render: (text, record, index) => {
if (!text) {
return '无';
}
return (
<Button
onClick={() => {
setModalImageUrl(text); // 更新图片URL状态
setIsModalOpenurl(true); // 打开模态框
}}
>
查看图片
</Button>
);
}
},
{
title: 'Prompt',
dataIndex: 'prompt',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
];
const [logs, setLogs] = useState([]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
// 定义模态框图片URL的状态和更新函数
const [modalImageUrl, setModalImageUrl] = useState('');
let now = new Date();
// 初始化start_timestamp为前一天
const [inputs, setInputs] = useState({
channel_id: '',
mj_id: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
});
const {channel_id, mj_id, start_timestamp, end_timestamp} = inputs;
const [stat, setStat] = useState({
quota: 0,
token: 0
});
const handleInputChange = (value, name) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: 'PromptEn',
dataIndex: 'prompt_en',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount);
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: '失败原因',
dataIndex: 'fail_reason',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
}
const loadLogs = async (startIdx) => {
setLoading(true);
];
let url = '';
let localStartTimestamp = Date.parse(start_timestamp);
let localEndTimestamp = Date.parse(end_timestamp);
if (isAdminUser) {
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} else {
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const {success, message, data} = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
setLogsFormat(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
const [logs, setLogs] = useState([]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [showBanner, setShowBanner] = useState(false);
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
// 定义模态框图片URL的状态和更新函数
const [modalImageUrl, setModalImageUrl] = useState('');
let now = new Date();
// 初始化start_timestamp为前一天
const [inputs, setInputs] = useState({
channel_id: '',
mj_id: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
});
const { channel_id, mj_id, start_timestamp, end_timestamp } = inputs;
const handlePageChange = page => {
setActivePage(page);
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
loadLogs(page - 1).then(r => {
});
}
};
const [stat, setStat] = useState({
quota: 0,
token: 0
});
const refresh = async () => {
// setLoading(true);
setActivePage(1);
await loadLogs(0);
};
const handleInputChange = (value, name) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const copyText = async (text) => {
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
}
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount);
};
useEffect(() => {
refresh().then();
}, [logType]);
const loadLogs = async (startIdx) => {
setLoading(true);
let url = '';
let localStartTimestamp = Date.parse(start_timestamp);
let localEndTimestamp = Date.parse(end_timestamp);
if (isAdminUser) {
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} else {
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
setLogsFormat(newLogs);
}
} else {
showError(message);
}
setLoading(false);
};
return (
<>
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
<Layout>
<Form layout='horizontal' style={{marginTop: 10}}>
<>
<Form.Input field="channel_id" label='渠道 ID' style={{width: 176}} value={channel_id}
placeholder={'可选值'} name='channel_id'
onChange={value => handleInputChange(value, 'channel_id')}/>
<Form.Input field="mj_id" label='任务 ID' style={{width: 176}} value={mj_id}
placeholder='可选值'
name='mj_id'
onChange={value => handleInputChange(value, 'mj_id')}/>
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
initValue={start_timestamp}
value={start_timestamp} type='dateTime'
name='start_timestamp'
onChange={value => handleInputChange(value, 'start_timestamp')}/>
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
initValue={end_timestamp}
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
const handlePageChange = page => {
setActivePage(page);
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
loadLogs(page - 1).then(r => {
});
}
};
<Form.Section>
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh}>查询</Button>
</Form.Section>
</>
</Form>
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
total: logCount,
pageSizeOpts: [10, 20, 50, 100],
onPageChange: handlePageChange,
}} loading={loading}/>
<Modal
visible={isModalOpen}
onOk={() => setIsModalOpen(false)}
onCancel={() => setIsModalOpen(false)}
closable={null}
bodyStyle={{height: '400px', overflow: 'auto'}} // 设置模态框内容区域样式
width={800} // 设置模态框宽度
>
<p style={{whiteSpace: 'pre-line'}}>{modalContent}</p>
</Modal>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
const refresh = async () => {
// setLoading(true);
setActivePage(1);
await loadLogs(0);
};
</Layout>
</>
);
const copyText = async (text) => {
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
// setSearchKeyword(text);
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
}
};
useEffect(() => {
refresh().then();
}, [logType]);
useEffect(() => {
const mjNotifyEnabled = localStorage.getItem('mj_notify_enabled');
if (mjNotifyEnabled !== 'true') {
setShowBanner(true);
}
}, []);
return (
<>
<Layout>
{isAdminUser && showBanner ? <Banner
type="info"
description="当前未开启Midjourney回调部分项目可能无法获得绘图结果可在运营设置中开启。"
/> : <></>
}
<Form layout="horizontal" style={{ marginTop: 10 }}>
<>
<Form.Input field="channel_id" label="渠道 ID" style={{ width: 176 }} value={channel_id}
placeholder={'可选值'} name="channel_id"
onChange={value => handleInputChange(value, 'channel_id')} />
<Form.Input field="mj_id" label="任务 ID" style={{ width: 176 }} value={mj_id}
placeholder="可选值"
name="mj_id"
onChange={value => handleInputChange(value, 'mj_id')} />
<Form.DatePicker field="start_timestamp" label="起始时间" style={{ width: 272 }}
initValue={start_timestamp}
value={start_timestamp} type="dateTime"
name="start_timestamp"
onChange={value => handleInputChange(value, 'start_timestamp')} />
<Form.DatePicker field="end_timestamp" fluid label="结束时间" style={{ width: 272 }}
initValue={end_timestamp}
value={end_timestamp} type="dateTime"
name="end_timestamp"
onChange={value => handleInputChange(value, 'end_timestamp')} />
<Form.Section>
<Button label="查询" type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh}>查询</Button>
</Form.Section>
</>
</Form>
<Table style={{ marginTop: 5 }} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
total: logCount,
pageSizeOpts: [10, 20, 50, 100],
onPageChange: handlePageChange
}} loading={loading} />
<Modal
visible={isModalOpen}
onOk={() => setIsModalOpen(false)}
onCancel={() => setIsModalOpen(false)}
closable={null}
bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
width={800} // 设置模态框宽度
>
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
</Modal>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
</Layout>
</>
);
};
export default LogsTable;

View File

@@ -1,453 +1,468 @@
import React, {useEffect, useState} from 'react';
import {Divider, Form, Grid, Header} from 'semantic-ui-react';
import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
import React, { useEffect, useState } from 'react';
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
const OperationSetting = () => {
let now = new Date();
let [inputs, setInputs] = useState({
QuotaForNewUser: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
ChatLink2: '', // 添加的新状态变量
QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: '',
AutomaticEnableChannelEnabled: '',
ChannelDisableThreshold: 0,
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
DefaultCollapseSidebar: '', // 默认折叠侧边栏
RetryTimes: 0
let now = new Date();
let [inputs, setInputs] = useState({
QuotaForNewUser: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
ChatLink2: '', // 添加的新状态变量
QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: '',
AutomaticEnableChannelEnabled: '',
ChannelDisableThreshold: 0,
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
MjNotifyEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
DefaultCollapseSidebar: '', // 默认折叠侧边栏
RetryTimes: 0
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
// 精确时间选项(小时,天,周)
const timeOptions = [
{ key: 'hour', text: '小时', value: 'hour' },
{ key: 'day', text: '天', value: 'day' },
{ key: 'week', text: '周', value: 'week' }
];
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
}
};
useEffect(() => {
getOptions().then();
}, []);
const updateOption = async (key, value) => {
setLoading(true);
if (key.endsWith('Enabled')) {
value = inputs[key] === 'true' ? 'false' : 'true';
}
if (key === 'DefaultCollapseSidebar') {
value = inputs[key] === 'true' ? 'false' : 'true';
}
console.log(key, value);
const res = await API.put('/api/option/', {
key,
value
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
// 精确时间选项(小时,天,周)
const timeOptions = [
{key: 'hour', text: '小时', value: 'hour'},
{key: 'day', text: '天', value: 'day'},
{key: 'week', text: '周', value: 'week'}
];
const getOptions = async () => {
const res = await API.get('/api/option/');
const {success, message, data} = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
}
};
const { success, message } = res.data;
if (success) {
setInputs((inputs) => ({ ...inputs, [key]: value }));
} else {
showError(message);
}
setLoading(false);
};
useEffect(() => {
getOptions().then();
}, []);
const handleInputChange = async (e, { name, value }) => {
if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime' || name === 'DefaultCollapseSidebar') {
if (name === 'DataExportDefaultTime') {
localStorage.setItem('data_export_default_time', value);
} else if (name === 'MjNotifyEnabled') {
localStorage.setItem('mj_notify_enabled', value);
}
await updateOption(name, value);
} else {
setInputs((inputs) => ({ ...inputs, [name]: value }));
}
};
const updateOption = async (key, value) => {
setLoading(true);
if (key.endsWith('Enabled')) {
value = inputs[key] === 'true' ? 'false' : 'true';
const submitConfig = async (group) => {
switch (group) {
case 'monitor':
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
}
if (key === 'DefaultCollapseSidebar') {
value = inputs[key] === 'true' ? 'false' : 'true';
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
console.log(key, value)
const res = await API.put('/api/option/', {
key,
value
});
const {success, message} = res.data;
if (success) {
setInputs((inputs) => ({...inputs, [key]: value}));
} else {
showError(message);
}
setLoading(false);
};
const handleInputChange = async (e, {name, value}) => {
if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime' || name === 'DefaultCollapseSidebar') {
if (name === 'DataExportDefaultTime') {
localStorage.setItem('data_export_default_time', value);
}
await updateOption(name, value);
} else {
setInputs((inputs) => ({...inputs, [name]: value}));
}
};
const submitConfig = async (group) => {
switch (group) {
case 'monitor':
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
}
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
if (!verifyJSON(inputs.ModelPrice)) {
showError('模型固定价格不是合法的 JSON 字符串');
return;
}
await updateOption('ModelPrice', inputs.ModelPrice);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['ChatLink2'] !== inputs.ChatLink2) {
await updateOption('ChatLink2', inputs.ChatLink2);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
}
};
const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const {success, message, data} = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
showError('日志清理失败:' + message);
};
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading}>
<Header as='h3'>
通用设置
</Header>
<Form.Group widths={4}>
<Form.Input
label='充值链接'
name='TopUpLink'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.TopUpLink}
type='link'
placeholder='例如发卡网站的购买链接'
/>
<Form.Input
label='默认聊天页面链接'
name='ChatLink'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.ChatLink}
type='link'
placeholder='例如 ChatGPT Next Web 的部署地址'
/>
<Form.Input
label='聊天页面2链接'
name='ChatLink2'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.ChatLink2}
type='link'
placeholder='例如 ChatGPT Web & Midjourney 的部署地址'
/>
<Form.Input
label='单位美元额度'
name='QuotaPerUnit'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.QuotaPerUnit}
type='number'
step='0.01'
placeholder='一单位货币能兑换的额度'
/>
<Form.Input
label='失败重试次数'
name='RetryTimes'
type={'number'}
step='1'
min='0'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.RetryTimes}
placeholder='失败重试次数'
/>
</Form.Group>
<Form.Group inline>
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
if (!verifyJSON(inputs.ModelPrice)) {
showError('模型固定价格不是合法的 JSON 字符串');
return;
}
await updateOption('ModelPrice', inputs.ModelPrice);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['ChatLink2'] !== inputs.ChatLink2) {
await updateOption('ChatLink2', inputs.ChatLink2);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
}
};
<Form.Checkbox
checked={inputs.DisplayInCurrencyEnabled === 'true'}
label='以货币形式显示额度'
name='DisplayInCurrencyEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DisplayTokenStatEnabled === 'true'}
label='Billing 相关 API 显示令牌额度而非用户额度'
name='DisplayTokenStatEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DrawingEnabled === 'true'}
label='启用绘图功能'
name='DrawingEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DefaultCollapseSidebar === 'true'}
label='默认折叠侧边栏'
name='DefaultCollapseSidebar'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('general').then();
}}>保存通用设置</Form.Button><Divider/>
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, {name, value}) => {
setHistoryTimestamp(value);
}}/>
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider/>
<Header as='h3'>
数据看板
</Header>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label='启用数据看板(实验性)'
name='DataExportEnabled'
onChange={handleInputChange}
/>
<Form.Group>
<Form.Input
label='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
name='DataExportInterval'
type={'number'}
step='1'
min='1'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DataExportInterval}
placeholder='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
/>
<Form.Select
label='数据看板默认时间粒度(仅修改展示粒度,统计精确到小时)'
options={timeOptions}
name='DataExportDefaultTime'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DataExportDefaultTime}
placeholder='数据看板默认时间粒度'
/>
</Form.Group>
<Divider/>
<Header as='h3'>
监控设置
</Header>
<Form.Group widths={3}>
<Form.Input
label='最长响应时间'
name='ChannelDisableThreshold'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道'
/>
<Form.Input
label='额度提醒阈值'
name='QuotaRemindThreshold'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.QuotaRemindThreshold}
type='number'
min='0'
placeholder='低于此额度时将发送邮件提醒用户'
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用通道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用通道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('monitor').then();
}}>保存监控设置</Form.Button>
<Divider/>
<Header as='h3'>
额度设置
</Header>
<Form.Group widths={4}>
<Form.Input
label='新用户初始额度'
name='QuotaForNewUser'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.QuotaForNewUser}
type='number'
min='0'
placeholder='例如100'
/>
<Form.Input
label='请求预扣费额度'
name='PreConsumedQuota'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.PreConsumedQuota}
type='number'
min='0'
placeholder='请求结束后多退少补'
/>
<Form.Input
label='邀请新用户奖励额度'
name='QuotaForInviter'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.QuotaForInviter}
type='number'
min='0'
placeholder='例如2000'
/>
<Form.Input
label='新用户使用邀请码奖励额度'
name='QuotaForInvitee'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.QuotaForInvitee}
type='number'
min='0'
placeholder='例如1000'
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('quota').then();
}}>保存额度设置</Form.Button>
<Divider/>
<Header as='h3'>
倍率设置
</Header>
<Form.Group widths='equal'>
<Form.TextArea
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
name='ModelPrice'
onChange={handleInputChange}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.ModelPrice}
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='模型倍率'
name='ModelRatio'
onChange={handleInputChange}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.ModelRatio}
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='分组倍率'
name='GroupRatio'
onChange={handleInputChange}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.GroupRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('ratio').then();
}}>保存倍率设置</Form.Button>
</Form>
</Grid.Column>
</Grid>
)
;
const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading}>
<Header as="h3">
通用设置
</Header>
<Form.Group widths={4}>
<Form.Input
label="充值链接"
name="TopUpLink"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.TopUpLink}
type="link"
placeholder="例如发卡网站的购买链接"
/>
<Form.Input
label="默认聊天页面链接"
name="ChatLink"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.ChatLink}
type="link"
placeholder="例如 ChatGPT Next Web 的部署地址"
/>
<Form.Input
label="聊天页面2链接"
name="ChatLink2"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.ChatLink2}
type="link"
placeholder="例如 ChatGPT Web & Midjourney 的部署地址"
/>
<Form.Input
label="单位美元额度"
name="QuotaPerUnit"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.QuotaPerUnit}
type="number"
step="0.01"
placeholder="一单位货币能兑换的额度"
/>
<Form.Input
label="失败重试次数"
name="RetryTimes"
type={'number'}
step="1"
min="0"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.RetryTimes}
placeholder="失败重试次数"
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.DisplayInCurrencyEnabled === 'true'}
label="以货币形式显示额度"
name="DisplayInCurrencyEnabled"
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DisplayTokenStatEnabled === 'true'}
label="Billing 相关 API 显示令牌额度而非用户额度"
name="DisplayTokenStatEnabled"
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DefaultCollapseSidebar === 'true'}
label="默认折叠侧边栏"
name="DefaultCollapseSidebar"
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('general').then();
}}>保存通用设置</Form.Button>
<Divider />
<Header as="h3">
绘图设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.DrawingEnabled === 'true'}
label="启用绘图功能"
name="DrawingEnabled"
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjNotifyEnabled === 'true'}
label="允许回调会泄露服务器ip地址"
name="MjNotifyEnabled"
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as="h3">
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label="启用额度消费日志记录"
name="LogConsumeEnabled"
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label="目标时间" value={historyTimestamp} type="datetime-local"
name="history_timestamp"
onChange={(e, { name, value }) => {
setHistoryTimestamp(value);
}} />
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider />
<Header as="h3">
数据看板
</Header>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label="启用数据看板(实验性)"
name="DataExportEnabled"
onChange={handleInputChange}
/>
<Form.Group>
<Form.Input
label="数据看板更新间隔(分钟,设置过短会影响数据库性能)"
name="DataExportInterval"
type={'number'}
step="1"
min="1"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.DataExportInterval}
placeholder="数据看板更新间隔(分钟,设置过短会影响数据库性能)"
/>
<Form.Select
label="数据看板默认时间粒度(仅修改展示粒度,统计精确到小时)"
options={timeOptions}
name="DataExportDefaultTime"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.DataExportDefaultTime}
placeholder="数据看板默认时间粒度"
/>
</Form.Group>
<Divider />
<Header as="h3">
监控设置
</Header>
<Form.Group widths={3}>
<Form.Input
label="最长响应时间"
name="ChannelDisableThreshold"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.ChannelDisableThreshold}
type="number"
min="0"
placeholder="单位秒,当运行通道全部测试时,超过此时间将自动禁用通道"
/>
<Form.Input
label="额度提醒阈值"
name="QuotaRemindThreshold"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.QuotaRemindThreshold}
type="number"
min="0"
placeholder="低于此额度时将发送邮件提醒用户"
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label="失败时自动禁用通道"
name="AutomaticDisableChannelEnabled"
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label="成功时自动启用通道"
name="AutomaticEnableChannelEnabled"
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('monitor').then();
}}>保存监控设置</Form.Button>
<Divider />
<Header as="h3">
额度设置
</Header>
<Form.Group widths={4}>
<Form.Input
label="新用户初始额度"
name="QuotaForNewUser"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.QuotaForNewUser}
type="number"
min="0"
placeholder="例如100"
/>
<Form.Input
label="请求预扣费额度"
name="PreConsumedQuota"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.PreConsumedQuota}
type="number"
min="0"
placeholder="请求结束后多退少补"
/>
<Form.Input
label="邀请新用户奖励额度"
name="QuotaForInviter"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.QuotaForInviter}
type="number"
min="0"
placeholder="例如2000"
/>
<Form.Input
label="新用户使用邀请码奖励额度"
name="QuotaForInvitee"
onChange={handleInputChange}
autoComplete="new-password"
value={inputs.QuotaForInvitee}
type="number"
min="0"
placeholder="例如1000"
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('quota').then();
}}>保存额度设置</Form.Button>
<Divider />
<Header as="h3">
倍率设置
</Header>
<Form.Group widths="equal">
<Form.TextArea
label="模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)"
name="ModelPrice"
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete="new-password"
value={inputs.ModelPrice}
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
/>
</Form.Group>
<Form.Group widths="equal">
<Form.TextArea
label="模型倍率"
name="ModelRatio"
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete="new-password"
value={inputs.ModelRatio}
placeholder="为一个 JSON 文本,键为模型名称,值为倍率"
/>
</Form.Group>
<Form.Group widths="equal">
<Form.TextArea
label="分组倍率"
name="GroupRatio"
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete="new-password"
value={inputs.GroupRatio}
placeholder="为一个 JSON 文本,键为分组名称,值为倍率"
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('ratio').then();
}}>保存倍率设置</Form.Button>
</Form>
</Grid.Column>
</Grid>
)
;
};
export default OperationSetting;

View File

@@ -1,15 +1,15 @@
import React, { useEffect, useState } from 'react';
import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react';
import React, { useEffect, useRef, useState } from 'react';
import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui';
import { API, showError, showSuccess } from '../helpers';
import { marked } from 'marked';
const OtherSetting = () => {
let [inputs, setInputs] = useState({
Footer: '',
Notice: '',
About: '',
SystemName: '',
Logo: '',
Footer: '',
About: '',
HomePageContent: ''
});
let [loading, setLoading] = useState(false);
@@ -19,25 +19,6 @@ const OtherSetting = () => {
content: ''
});
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key in inputs) {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
useEffect(() => {
getOptions().then();
}, []);
const updateOption = async (key, value) => {
setLoading(true);
@@ -54,33 +35,103 @@ const OtherSetting = () => {
setLoading(false);
};
const handleInputChange = async (e, { name, value }) => {
const [loadingInput, setLoadingInput] = useState({
Notice: false,
SystemName: false,
Logo: false,
HomePageContent: false,
About: false,
Footer: false
});
const handleInputChange = async (value, e) => {
const name = e.target.id;
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
// 通用设置
const formAPISettingGeneral = useRef();
// 通用设置 - Notice
const submitNotice = async () => {
await updateOption('Notice', inputs.Notice);
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, Notice: true }));
await updateOption('Notice', inputs.Notice);
showSuccess('公告已更新');
} catch (error) {
console.error('公告更新失败', error);
showError('公告更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, Notice: false }));
}
};
const submitFooter = async () => {
await updateOption('Footer', inputs.Footer);
};
// 个性化设置
const formAPIPersonalization = useRef();
// 个性化设置 - SystemName
const submitSystemName = async () => {
await updateOption('SystemName', inputs.SystemName);
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, SystemName: true }));
await updateOption('SystemName', inputs.SystemName);
showSuccess('系统名称已更新');
} catch (error) {
console.error('系统名称更新失败', error);
showError('系统名称更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, SystemName: false }));
}
};
// 个性化设置 - Logo
const submitLogo = async () => {
await updateOption('Logo', inputs.Logo);
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, Logo: true }));
await updateOption('Logo', inputs.Logo);
showSuccess('Logo 已更新');
} catch (error) {
console.error('Logo 更新失败', error);
showError('Logo 更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, Logo: false }));
}
};
const submitAbout = async () => {
await updateOption('About', inputs.About);
};
// 个性化设置 - 首页内容
const submitOption = async (key) => {
await updateOption(key, inputs[key]);
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, HomePageContent: true }));
await updateOption(key, inputs[key]);
showSuccess('首页内容已更新');
} catch (error) {
console.error('首页内容更新失败', error);
showError('首页内容更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, HomePageContent: false }));
}
};
// 个性化设置 - 关于
const submitAbout = async () => {
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, About: true }));
await updateOption('About', inputs.About);
showSuccess('关于内容已更新');
} catch (error) {
console.error('关于内容更新失败', error);
showError('关于内容更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, About: false }));
}
};
// 个性化设置 - 页脚
const submitFooter = async () => {
try {
setLoadingInput((loadingInput) => ({ ...loadingInput, Footer: true }));
await updateOption('Footer', inputs.Footer);
showSuccess('页脚内容已更新');
} catch (error) {
console.error('页脚内容更新失败', error);
showError('页脚内容更新失败');
} finally {
setLoadingInput((loadingInput) => ({ ...loadingInput, Footer: false }));
}
};
const openGitHubRelease = () => {
window.location =
@@ -102,82 +153,102 @@ const OtherSetting = () => {
setShowUpdateModal(true);
}
};
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key in inputs) {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
formAPISettingGeneral.current.setValues(newInputs);
formAPIPersonalization.current.setValues(newInputs);
} else {
showError(message);
}
};
useEffect(() => {
getOptions();
}, []);
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading}>
<Header as='h3'>通用设置</Header>
{/*<Form.Button onClick={checkUpdate}>检查更新</Form.Button>*/}
<Form.Group widths='equal'>
<Row>
<Col span={24}>
{/* 通用设置 */}
<Form values={inputs} getFormApi={formAPI => formAPISettingGeneral.current = formAPI}
style={{ marginBottom: 15 }}>
<Form.Section text={'通用设置'}>
<Form.TextArea
label='公告'
placeholder='在此输入新的公告内容,支持 Markdown & HTML 代码'
value={inputs.Notice}
name='Notice'
label={'公告'}
placeholder={'在此输入新的公告内容,支持 Markdown & HTML 代码'}
field={'Notice'}
onChange={handleInputChange}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
autosize={{ minRows: 6, maxRows: 12 }}
/>
</Form.Group>
<Form.Button onClick={submitNotice}>保存公告</Form.Button>
<Divider />
<Header as='h3'>个性化设置</Header>
<Form.Group widths='equal'>
<Form.Input
label='系统名称'
placeholder='在此输入系统名称'
value={inputs.SystemName}
name='SystemName'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitSystemName}>设置系统名称</Form.Button>
<Form.Group widths='equal'>
<Form.Input
label='Logo 图片地址'
placeholder='在此输入 Logo 图片地址'
value={inputs.Logo}
name='Logo'
type='url'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitLogo}>设置 Logo</Form.Button>
<Form.Group widths='equal'>
<Form.TextArea
label='首页内容'
placeholder='在此输入首页内容,支持 Markdown & HTML 代码,设置后首页的状态信息将不再显示。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。'
value={inputs.HomePageContent}
name='HomePageContent'
onChange={handleInputChange}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
/>
</Form.Group>
<Form.Button onClick={() => submitOption('HomePageContent')}>保存首页内容</Form.Button>
<Form.Group widths='equal'>
<Form.TextArea
label='关于'
placeholder='在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面。'
value={inputs.About}
name='About'
onChange={handleInputChange}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
/>
</Form.Group>
<Form.Button onClick={submitAbout}>保存关于</Form.Button>
<Message>移除 One API 的版权标识必须首先获得授权项目维护需要花费大量精力如果本项目对你有意义请主动支持本项目</Message>
<Form.Group widths='equal'>
<Form.Input
label='页脚'
placeholder='在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码'
value={inputs.Footer}
name='Footer'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitFooter}>设置页脚</Form.Button>
<Button onClick={submitNotice} loading={loadingInput['Notice']}>设置公告</Button>
</Form.Section>
</Form>
</Grid.Column>
{/* 个性化设置 */}
<Form values={inputs} getFormApi={formAPI => formAPIPersonalization.current = formAPI}
style={{ marginBottom: 15 }}>
<Form.Section text={'个性化设置'}>
<Form.Input
label={'系统名称'}
placeholder={'在此输入系统名称'}
field={'SystemName'}
onChange={handleInputChange}
/>
<Button onClick={submitSystemName} loading={loadingInput['SystemName']}>设置系统名称</Button>
<Form.Input
label={'Logo 图片地址'}
placeholder={'在此输入 Logo 图片地址'}
field={'Logo'}
onChange={handleInputChange}
/>
<Button onClick={submitLogo} loading={loadingInput['Logo']}>设置 Logo</Button>
<Form.TextArea
label={'首页内容'}
placeholder={'在此输入首页内容,支持 Markdown & HTML 代码,设置后首页的状态信息将不再显示。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。'}
field={'HomePageContent'}
onChange={handleInputChange}
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
autosize={{ minRows: 6, maxRows: 12 }}
/>
<Button onClick={() => submitOption('HomePageContent')}
loading={loadingInput['HomePageContent']}>设置首页内容</Button>
<Form.TextArea
label={'关于'}
placeholder={'在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面。'}
field={'About'}
onChange={handleInputChange}
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
autosize={{ minRows: 6, maxRows: 12 }}
/>
<Button onClick={submitAbout} loading={loadingInput['About']}>设置关于</Button>
{/* */}
<Banner
fullMode={false}
type="info"
description="移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。"
closeIcon={null}
style={{ marginTop: 15 }}
/>
<Form.Input
label={'页脚'}
placeholder={'在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码'}
field={'Footer'}
onChange={handleInputChange}
/>
<Button onClick={submitFooter} loading={loadingInput['Footer']}>设置页脚</Button>
</Form.Section>
</Form>
</Col>
{/*<Modal*/}
{/* onClose={() => setShowUpdateModal(false)}*/}
{/* onOpen={() => setShowUpdateModal(true)}*/}
@@ -200,7 +271,7 @@ const OtherSetting = () => {
{/* />*/}
{/* </Modal.Actions>*/}
{/*</Modal>*/}
</Grid>
</Row>
);
};

View File

@@ -1,12 +1,12 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import { API, copy, showError, showNotice } from '../helpers';
import { useSearchParams } from 'react-router-dom';
const PasswordResetConfirm = () => {
const [inputs, setInputs] = useState({
email: '',
token: '',
token: ''
});
const { email, token } = inputs;
@@ -23,7 +23,7 @@ const PasswordResetConfirm = () => {
let email = searchParams.get('email');
setInputs({
token,
email,
email
});
}, []);
@@ -37,7 +37,7 @@ const PasswordResetConfirm = () => {
setDisableButton(false);
setCountdown(30);
}
return () => clearInterval(countdownInterval);
return () => clearInterval(countdownInterval);
}, [disableButton, countdown]);
async function handleSubmit(e) {
@@ -46,7 +46,7 @@ const PasswordResetConfirm = () => {
setLoading(true);
const res = await API.post(`/api/user/reset`, {
email,
token,
token
});
const { success, message } = res.data;
if (success) {
@@ -59,44 +59,44 @@ const PasswordResetConfirm = () => {
}
setLoading(false);
}
return (
<Grid textAlign='center' style={{ marginTop: '48px' }}>
<Grid textAlign="center" style={{ marginTop: '48px' }}>
<Grid.Column style={{ maxWidth: 450 }}>
<Header as='h2' color='' textAlign='center'>
<Image src='/logo.png' /> 密码重置确认
<Header as="h2" color="" textAlign="center">
<Image src="/logo.png" /> 密码重置确认
</Header>
<Form size='large'>
<Form size="large">
<Segment>
<Form.Input
fluid
icon='mail'
iconPosition='left'
placeholder='邮箱地址'
name='email'
icon="mail"
iconPosition="left"
placeholder="邮箱地址"
name="email"
value={email}
readOnly
/>
{newPassword && (
<Form.Input
fluid
icon='lock'
iconPosition='left'
placeholder='新密码'
name='newPassword'
value={newPassword}
readOnly
onClick={(e) => {
e.target.select();
navigator.clipboard.writeText(newPassword);
showNotice(`密码已复制到剪贴板:${newPassword}`);
}}
/>
fluid
icon="lock"
iconPosition="left"
placeholder="新密码"
name="newPassword"
value={newPassword}
readOnly
onClick={(e) => {
e.target.select();
navigator.clipboard.writeText(newPassword);
showNotice(`密码已复制到剪贴板:${newPassword}`);
}}
/>
)}
<Button
color='green'
color="green"
fluid
size='large'
size="large"
onClick={handleSubmit}
loading={loading}
disabled={disableButton}
@@ -107,7 +107,7 @@ const PasswordResetConfirm = () => {
</Form>
</Grid.Column>
</Grid>
);
);
};
export default PasswordResetConfirm;

View File

@@ -56,19 +56,19 @@ const PasswordResetForm = () => {
}
return (
<Grid textAlign='center' style={{ marginTop: '48px' }}>
<Grid textAlign="center" style={{ marginTop: '48px' }}>
<Grid.Column style={{ maxWidth: 450 }}>
<Header as='h2' color='' textAlign='center'>
<Image src='/logo.png' /> 密码重置
<Header as="h2" color="" textAlign="center">
<Image src="/logo.png" /> 密码重置
</Header>
<Form size='large'>
<Form size="large">
<Segment>
<Form.Input
fluid
icon='mail'
iconPosition='left'
placeholder='邮箱地址'
name='email'
icon="mail"
iconPosition="left"
placeholder="邮箱地址"
name="email"
value={email}
onChange={handleChange}
/>
@@ -83,9 +83,9 @@ const PasswordResetForm = () => {
<></>
)}
<Button
color='green'
color="green"
fluid
size='large'
size="large"
onClick={handleSubmit}
loading={loading}
disabled={disableButton}

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ import { history } from '../helpers';
function PrivateRoute({ children }) {
if (!localStorage.getItem('user')) {
return <Navigate to='/login' state={{ from: history.location }} />;
return <Navigate to="/login" state={{ from: history.location }} />;
}
return children;
}

Some files were not shown because too many files have changed in this diff Show More