mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-18 11:33:42 +08:00
Compare commits
94 Commits
v0.2.0
...
v0.2.3-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b55116563 | ||
|
|
f35e63e3f3 | ||
|
|
17c409de23 | ||
|
|
e4753e7411 | ||
|
|
bec21ade9d | ||
|
|
4c4e087060 | ||
|
|
beadb98a8c | ||
|
|
615d109d70 | ||
|
|
0da49fa446 | ||
|
|
578b5f6536 | ||
|
|
f63ad9c03c | ||
|
|
0fb98e44a7 | ||
|
|
652bb4a53c | ||
|
|
a26b9a9bff | ||
|
|
f82aa956bd | ||
|
|
b8c053c37f | ||
|
|
2581b37394 | ||
|
|
f65477d054 | ||
|
|
4b952b8582 | ||
|
|
a6ba1d01d9 | ||
|
|
2d2fec24d0 | ||
|
|
d34b55c154 | ||
|
|
25ec99913b | ||
|
|
7f22d58574 | ||
|
|
dfdeadf1a5 | ||
|
|
9adefa80b9 | ||
|
|
6ab1b3a524 | ||
|
|
4917e5a92f | ||
|
|
4ce2381182 | ||
|
|
62afc21ea5 | ||
|
|
7ddb7c586d | ||
|
|
f62dcbf669 | ||
|
|
299911d4cd | ||
|
|
84e0544604 | ||
|
|
2786a6b539 | ||
|
|
9b5353a81a | ||
|
|
bc5a54df59 | ||
|
|
d704902b70 | ||
|
|
614220a0fb | ||
|
|
98c1f66d61 | ||
|
|
a77fbc0fa2 | ||
|
|
44361d75e8 | ||
|
|
9b2e5c2978 | ||
|
|
d3399d68f6 | ||
|
|
3d10c9f090 | ||
|
|
d5ffaf2502 | ||
|
|
2ad591411e | ||
|
|
728dbed28d | ||
|
|
fd3a41bacb | ||
|
|
37c0c8ebdd | ||
|
|
95d8059c90 | ||
|
|
c012306400 | ||
|
|
1e1a53e4d2 | ||
|
|
30d48ea473 | ||
|
|
9e5c636490 | ||
|
|
d53d3386e9 | ||
|
|
c3a01decd8 | ||
|
|
c29680b301 | ||
|
|
ac7407ce9c | ||
|
|
3ea061a820 | ||
|
|
d4df1960b2 | ||
|
|
97dd80541b | ||
|
|
bf241b218f | ||
|
|
7ab6c6c303 | ||
|
|
8fe8340b6e | ||
|
|
26ef906c61 | ||
|
|
655dfe0d09 | ||
|
|
f43b268520 | ||
|
|
37113c0e96 | ||
|
|
3c3c53051d | ||
|
|
6a83c8ad86 | ||
|
|
0640dd81fd | ||
|
|
cb50fcaffe | ||
|
|
eca48268b2 | ||
|
|
4a0af1ea3c | ||
|
|
c2965eb835 | ||
|
|
4186880e4c | ||
|
|
280c63e1d4 | ||
|
|
9e0a54e943 | ||
|
|
0e06be8c3e | ||
|
|
931d22c96f | ||
|
|
6413bf342a | ||
|
|
626217fbd4 | ||
|
|
9de2d21e1a | ||
|
|
3ab4f145db | ||
|
|
da73dca9a7 | ||
|
|
415d296171 | ||
|
|
d5c5c30312 | ||
|
|
fb39b6b30e | ||
|
|
92c1ed7f1d | ||
|
|
d160736a49 | ||
|
|
fe7f42fc2e | ||
|
|
5cb933a278 | ||
|
|
ac64fd26ad |
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
- name: 交流社区
|
||||
url: https://linux.do
|
||||
about: 项目交流社区
|
||||
|
||||
2
.github/workflows/docker-image-amd64.yml
vendored
2
.github/workflows/docker-image-amd64.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
2
.github/workflows/docker-image-arm64.yml
vendored
2
.github/workflows/docker-image-arm64.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
23
Dockerfile
23
Dockerfile
@@ -1,32 +1,29 @@
|
||||
FROM node:16 as builder
|
||||
FROM node:16-slim as builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
RUN npm install
|
||||
COPY web/yarn.lock .
|
||||
RUN yarn install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
|
||||
FROM golang AS builder2
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
FROM golang:1.19-alpine AS builder2
|
||||
RUN apk add --no-cache build-base
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=1 \
|
||||
GOOS=linux
|
||||
|
||||
WORKDIR /build
|
||||
ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
#ADD go.mod go.sum ./
|
||||
COPY . .
|
||||
COPY --from=builder /build/build ./web/build
|
||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
||||
|
||||
RUN go mod tidy \
|
||||
&& go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
&& apk add --no-cache ca-certificates tzdata \
|
||||
&& update-ca-certificates 2>/dev/null || true
|
||||
|
||||
COPY --from=builder2 /build/one-api /
|
||||
EXPOSE 3000
|
||||
WORKDIR /data
|
||||
|
||||
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
|
||||
@cd $(FRONTEND_DIR) && yarn install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
314
Midjourney.md
314
Midjourney.md
@@ -2,290 +2,66 @@
|
||||
|
||||
**简介**:Midjourney Proxy API文档
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
## 模型列表
|
||||
|
||||
### midjourney-proxy支持
|
||||
|
||||
- mj_imagine (绘图)
|
||||
- mj_variation (变换)
|
||||
- mj_reroll (重绘)
|
||||
- mj_blend (混合)
|
||||
- mj_upscale (放大)
|
||||
- mj_describe (图生文)
|
||||
|
||||
### 仅midjourney-proxy-plus支持
|
||||
|
||||
- mj_zoom (比例变焦)
|
||||
- mj_shorten (提示词缩短)
|
||||
- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
|
||||
- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
|
||||
- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
|
||||
- mj_high_variation (强变换)
|
||||
- mj_low_variation (弱变换)
|
||||
- mj_pan (平移)
|
||||
- swap_face (换脸)
|
||||
|
||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
||||
```json
|
||||
{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_modal": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05
|
||||
}
|
||||
```
|
||||
其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
|
||||
|
||||
## 渠道设置
|
||||
|
||||
### 对接 midjourney-proxy
|
||||
1. 部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||
2. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
||||
### 对接 midjourney-proxy(plus)
|
||||
|
||||
1.
|
||||
|
||||
部署Midjourney-Proxy,并配置好midjourney账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
|
||||
|
||||
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
||||
,模型请参考上方模型列表
|
||||
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
||||
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
||||
|
||||
### 对接上游new api
|
||||
1. 在渠道管理中添加渠道,渠道类型选择Midjourney Proxy,模型选择midjourney
|
||||
2. 地址填写上游new api的地址,例如:http://localhost:8080
|
||||
3. 密钥填写上游new api的密钥
|
||||
|
||||
## 任务提交
|
||||
|
||||
### 绘图变化
|
||||
|
||||
**接口地址**:`/mj/submit/change`
|
||||
|
||||
**请求方式**:`POST`
|
||||
|
||||
**请求数据类型**:`application/json`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"action"
|
||||
:
|
||||
"UPSCALE",
|
||||
"index"
|
||||
:
|
||||
1,
|
||||
"notifyHook"
|
||||
:
|
||||
"",
|
||||
"state"
|
||||
:
|
||||
"",
|
||||
"taskId"
|
||||
:
|
||||
"1320098173412546"
|
||||
}
|
||||
```
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
|
||||
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
|
||||
|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
|
||||
|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
|
||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
||||
|   state | 自定义参数 | | false | string | |
|
||||
|   taskId | 任务ID | | true | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 提交结果 |
|
||||
| 201 | Created | |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|-------------------------------------------|----------------|----------------|
|
||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
||||
| description | 描述 | string | |
|
||||
| properties | 扩展字段 | object | |
|
||||
| result | 任务ID | string | |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"code"
|
||||
:
|
||||
1,
|
||||
"description"
|
||||
:
|
||||
"提交成功",
|
||||
"properties"
|
||||
:
|
||||
{
|
||||
}
|
||||
,
|
||||
"result"
|
||||
:
|
||||
1320098173412546
|
||||
}
|
||||
```
|
||||
|
||||
### 提交Imagine任务
|
||||
|
||||
**接口地址**:`/mj/submit/imagine`
|
||||
|
||||
**请求方式**:`POST`
|
||||
|
||||
**请求数据类型**:`application/json`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"base64"
|
||||
:
|
||||
"",
|
||||
"notifyHook"
|
||||
:
|
||||
"",
|
||||
"prompt"
|
||||
:
|
||||
"Cat",
|
||||
"state"
|
||||
:
|
||||
""
|
||||
}
|
||||
```
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------------------------|-------------------------|------|-------|-------------|-------------|
|
||||
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
|
||||
|   base64 | 垫图base64 | | false | string | |
|
||||
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|
||||
|   prompt | 提示词 | | true | string | |
|
||||
|   state | 自定义参数 | | false | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 提交结果 |
|
||||
| 201 | Created | |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|-------------------------------------------|----------------|----------------|
|
||||
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
|
||||
| description | 描述 | string | |
|
||||
| properties | 扩展字段 | object | |
|
||||
| result | 任务ID | string | |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"code"
|
||||
:
|
||||
1,
|
||||
"description"
|
||||
:
|
||||
"提交成功",
|
||||
"properties"
|
||||
:
|
||||
{
|
||||
}
|
||||
,
|
||||
"result"
|
||||
:
|
||||
1320098173412546
|
||||
}
|
||||
```
|
||||
|
||||
## 任务查询
|
||||
|
||||
### 指定ID获取任务
|
||||
|
||||
**接口地址**:`/mj/task/{id}/fetch`
|
||||
|
||||
**请求方式**:`GET`
|
||||
|
||||
**请求数据类型**:`application/x-www-form-urlencoded`
|
||||
|
||||
**响应数据类型**:`*/*`
|
||||
|
||||
**接口描述**:
|
||||
|
||||
**请求参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|
||||
|------|------|------|-------|--------|--------|
|
||||
| id | 任务ID | path | false | string | |
|
||||
|
||||
**响应状态**:
|
||||
|
||||
| 状态码 | 说明 | schema |
|
||||
|-----|--------------|--------|
|
||||
| 200 | OK | 任务 |
|
||||
| 401 | Unauthorized | |
|
||||
| 403 | Forbidden | |
|
||||
| 404 | Not Found | |
|
||||
|
||||
**响应参数**:
|
||||
|
||||
| 参数名称 | 参数说明 | 类型 | schema |
|
||||
|-------------|----------------------------------------------------------|----------------|----------------|
|
||||
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
|
||||
| description | 任务描述 | string | |
|
||||
| failReason | 失败原因 | string | |
|
||||
| finishTime | 结束时间 | integer(int64) | integer(int64) |
|
||||
| id | 任务ID | string | |
|
||||
| imageUrl | 图片url | string | |
|
||||
| progress | 任务进度 | string | |
|
||||
| prompt | 提示词 | string | |
|
||||
| promptEn | 提示词-英文 | string | |
|
||||
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
|
||||
| state | 自定义参数 | string | |
|
||||
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
|
||||
| submitTime | 提交时间 | integer(int64) | integer(int64) |
|
||||
|
||||
**响应示例**:
|
||||
|
||||
```javascript
|
||||
{
|
||||
"action"
|
||||
:
|
||||
"",
|
||||
"description"
|
||||
:
|
||||
"",
|
||||
"failReason"
|
||||
:
|
||||
"",
|
||||
"finishTime"
|
||||
:
|
||||
0,
|
||||
"id"
|
||||
:
|
||||
"",
|
||||
"imageUrl"
|
||||
:
|
||||
"",
|
||||
"progress"
|
||||
:
|
||||
"",
|
||||
"prompt"
|
||||
:
|
||||
"",
|
||||
"promptEn"
|
||||
:
|
||||
"",
|
||||
"startTime"
|
||||
:
|
||||
0,
|
||||
"state"
|
||||
:
|
||||
"",
|
||||
"status"
|
||||
:
|
||||
"",
|
||||
"submitTime"
|
||||
:
|
||||
0
|
||||
}
|
||||
```
|
||||
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
|
||||
2. 地址填写上游new api的地址,例如:http://localhost:3000
|
||||
3. 密钥填写上游new api的密钥
|
||||
22
README.md
22
README.md
@@ -18,7 +18,7 @@
|
||||
此分叉版本的主要变更如下:
|
||||
|
||||
1. 全新的UI界面(部分界面还待更新)
|
||||
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持
|
||||
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
|
||||
+ [x] /mj/submit/imagine
|
||||
+ [x] /mj/submit/change
|
||||
+ [x] /mj/submit/blend
|
||||
@@ -26,6 +26,11 @@
|
||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||
+ [x] /task/list-by-condition
|
||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||
+ [x] /mj/submit/modal
|
||||
+ [x] /mj/submit/shorten
|
||||
+ [x] /mj/task/{id}/image-seed
|
||||
+ [x] /mj/insight-face/swap (InsightFace)
|
||||
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
||||
+ [x] 易支付
|
||||
4. 支持用key查询使用额度:
|
||||
@@ -37,12 +42,19 @@
|
||||
9. 支持渠道**加权随机**
|
||||
10. 数据看板
|
||||
11. 可设置令牌能调用的模型
|
||||
12. 支持Telegram授权登录
|
||||
12. 支持Telegram授权登录。
|
||||
1. 系统设置-配置登录注册-允许通过Telegram登录
|
||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||
2. 智谱glm-4v,glm-4v识图
|
||||
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
@@ -86,6 +98,12 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||

|
||||

|
||||
|
||||
## 相关项目
|
||||
- [One API](https://github.com/songquanpeng/one-api):原版项目
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
|
||||
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool):用key查询使用额度
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -50,6 +50,7 @@ var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var LinuxDoOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
@@ -82,6 +83,10 @@ var SMTPToken = ""
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDoClientId = ""
|
||||
var LinuxDoClientSecret = ""
|
||||
var LinuxDoMinLevel = 0
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
@@ -186,10 +191,10 @@ const (
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeAPI2D = 2
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeCloseAI = 4
|
||||
ChannelTypeOpenAISB = 5
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
@@ -211,6 +216,7 @@ const (
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@@ -218,7 +224,7 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
@@ -238,7 +244,8 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
}
|
||||
|
||||
@@ -2,5 +2,6 @@ package common
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var UsingMySQL = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/chai2010/webp"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
@@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, "", err
|
||||
return image.Config{}, "", "", err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, format, err := getImageConfig(reader)
|
||||
return config, format, err
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
@@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// TODO: when a new api is enabled, check the pricing here
|
||||
// 1 === $0.002 / 1K tokens
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
var ModelRatio = map[string]float64{
|
||||
var DefaultModelRatio = map[string]float64{
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4": 15,
|
||||
@@ -61,8 +61,12 @@ var ModelRatio = map[string]float64{
|
||||
"text-moderation-latest": 0.1,
|
||||
"dall-e-2": 8,
|
||||
"dall-e-3": 16,
|
||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||
"claude-instant-1": 0.4, // $0.8 / 1M tokens
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
@@ -91,17 +95,32 @@ var ModelRatio = map[string]float64{
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
}
|
||||
|
||||
var ModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_modal": 0.1,
|
||||
"mj_zoom": 0.1,
|
||||
"mj_shorten": 0.1,
|
||||
"mj_high_variation": 0.1,
|
||||
"mj_low_variation": 0.1,
|
||||
"mj_pan": 0.1,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
"swap_face": 0.05,
|
||||
}
|
||||
|
||||
var ModelPrice = map[string]float64{}
|
||||
var ModelRatio = map[string]float64{}
|
||||
|
||||
func ModelPrice2JSONString() string {
|
||||
if len(ModelPrice) == 0 {
|
||||
ModelPrice = DefaultModelPrice
|
||||
}
|
||||
jsonBytes, err := json.Marshal(ModelPrice)
|
||||
if err != nil {
|
||||
SysError("error marshalling model price: " + err.Error())
|
||||
@@ -115,6 +134,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
}
|
||||
|
||||
func GetModelPrice(name string, printErr bool) float64 {
|
||||
if len(ModelPrice) == 0 {
|
||||
ModelPrice = DefaultModelPrice
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
@@ -129,6 +151,9 @@ func GetModelPrice(name string, printErr bool) float64 {
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
if len(ModelRatio) == 0 {
|
||||
ModelRatio = DefaultModelRatio
|
||||
}
|
||||
jsonBytes, err := json.Marshal(ModelRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
@@ -142,6 +167,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) float64 {
|
||||
if len(ModelRatio) == 0 {
|
||||
ModelRatio = DefaultModelRatio
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
@@ -179,10 +207,11 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
return 3.38
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-2") {
|
||||
return 2.965517
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
44
constant/midjourney.go
Normal file
44
constant/midjourney.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MjActionImagine = "IMAGINE"
|
||||
MjActionDescribe = "DESCRIBE"
|
||||
MjActionBlend = "BLEND"
|
||||
MjActionUpscale = "UPSCALE"
|
||||
MjActionVariation = "VARIATION"
|
||||
MjActionReRoll = "REROLL"
|
||||
MjActionInPaint = "INPAINT"
|
||||
MjActionModal = "MODAL"
|
||||
MjActionZoom = "ZOOM"
|
||||
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||
MjActionShorten = "SHORTEN"
|
||||
MjActionHighVariation = "HIGH_VARIATION"
|
||||
MjActionLowVariation = "LOW_VARIATION"
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
"mj_imagine": MjActionImagine,
|
||||
"mj_describe": MjActionDescribe,
|
||||
"mj_blend": MjActionBlend,
|
||||
"mj_upscale": MjActionUpscale,
|
||||
"mj_variation": MjActionVariation,
|
||||
"mj_reroll": MjActionReRoll,
|
||||
"mj_modal": MjActionModal,
|
||||
"mj_inpaint": MjActionInPaint,
|
||||
"mj_zoom": MjActionZoom,
|
||||
"mj_custom_zoom": MjActionCustomZoom,
|
||||
"mj_shorten": MjActionShorten,
|
||||
"mj_high_variation": MjActionHighVariation,
|
||||
"mj_low_variation": MjActionLowVariation,
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
}
|
||||
@@ -214,10 +214,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
case common.ChannelTypeCloseAI:
|
||||
return updateChannelCloseAIBalance(channel)
|
||||
case common.ChannelTypeOpenAISB:
|
||||
return updateChannelOpenAISBBalance(channel)
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
|
||||
@@ -24,6 +24,9 @@ import (
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -37,6 +40,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
@@ -45,13 +61,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
}
|
||||
if testModel == "" {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
meta.UpstreamModelName = testModel
|
||||
}
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
@@ -68,11 +85,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
|
||||
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
}
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
|
||||
@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
|
||||
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
239
controller/linuxdo.go
Normal file
239
controller/linuxdo.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LinuxDoOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type LinuxDoUser struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LinuxDoOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var linuxdoUser LinuxDoUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if linuxdoUser.ID == 0 {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
|
||||
return nil, errors.New("用户 LINUX DO 信任等级不足!")
|
||||
}
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxDoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
err := user.FillUserByLinuxDoId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
affCode := c.Query("aff")
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
|
||||
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if linuxdoUser.Name != "" {
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
} else {
|
||||
user.DisplayName = linuxdoUser.Username
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 LINUX DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
@@ -20,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -38,16 +42,23 @@ func GetAllLogs(c *gin.Context) {
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -10,145 +10,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relay2 "one-api/relay"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*func UpdateMidjourneyTask() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
imageModel := "midjourney"
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
for _, task := range tasks {
|
||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
||||
task.Status = "FAILURE"
|
||||
task.Progress = "100%"
|
||||
err := task.Update()
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
||||
|
||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
var responseItem Midjourney
|
||||
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
|
||||
err = json.Unmarshal(responseBody, &responseItem)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
|
||||
var responseWithoutStatus MidjourneyWithoutStatus
|
||||
var responseStatus MidjourneyStatus
|
||||
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
|
||||
err2 := json.Unmarshal(responseBody, &responseStatus)
|
||||
if err1 == nil && err2 == nil {
|
||||
jsonData, err3 := json.Marshal(responseWithoutStatus)
|
||||
if err3 != nil {
|
||||
log.Printf("UpdateMidjourneyTask error1: %v", err3)
|
||||
continue
|
||||
}
|
||||
err4 := json.Unmarshal(jsonData, &responseStatus)
|
||||
if err4 != nil {
|
||||
log.Printf("UpdateMidjourneyTask error2: %v", err4)
|
||||
continue
|
||||
}
|
||||
responseItem.Status = strconv.Itoa(responseStatus.Status)
|
||||
} else {
|
||||
log.Printf("UpdateMidjourneyTask error3: %v", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
log.Printf("UpdateMidjourneyTask error4: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
log.Println("error update user quota cache: " + err.Error())
|
||||
} else {
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
ratio := modelRatio * groupRatio
|
||||
quota := int(ratio * 1 * 1000)
|
||||
if quota != 0 {
|
||||
err := model.IncreaseUserQuota(task.UserId, quota)
|
||||
if err != nil {
|
||||
log.Println("fail to increase user quota")
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error5: %v", err)
|
||||
}
|
||||
log.Printf("UpdateMidjourneyTask success: %v", task)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
@@ -228,12 +97,16 @@ func UpdateMidjourneyTaskBulk() {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []relay2.Midjourney
|
||||
var responseItems []dto.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
@@ -245,10 +118,16 @@ func UpdateMidjourneyTaskBulk() {
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
|
||||
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||
if useTime > 3600000 && task.Progress != "100%" {
|
||||
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||
responseItem.Status = "FAILURE"
|
||||
}
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
@@ -259,6 +138,15 @@ func UpdateMidjourneyTaskBulk() {
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if responseItem.Properties != nil {
|
||||
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||
task.Properties = string(propertiesStr)
|
||||
}
|
||||
if responseItem.Buttons != nil {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
@@ -286,7 +174,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask relay2.Midjourney) bool {
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -5,12 +5,29 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestStatus(c *gin.Context) {
|
||||
err := model.PingDB()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"success": false,
|
||||
"message": "数据库连接失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -20,6 +37,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDoClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
@@ -43,6 +62,8 @@ func GetStatus(c *gin.Context) {
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
"version": common.Version,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -59,8 +60,8 @@ func init() {
|
||||
IsBlocking: false,
|
||||
})
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
@@ -100,6 +101,17 @@ func init() {
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range openAIModels {
|
||||
openAIModelsMap[model.Id] = model
|
||||
|
||||
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDoOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDoClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 LINUX DO OAuth,请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
@@ -38,83 +37,58 @@ func Relay(c *gin.Context) {
|
||||
retryTimes = common.RetryTimes
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
"error": err.Error,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
|
||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelName := c.GetString("channel_name")
|
||||
service.DisableChannel(channelId, channelName, err.Message)
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyImagine
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyBlend
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyDescribe
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyTaskFetch
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
|
||||
relayMode = relayconstant.RelayModeMidjourneyTaskFetchByCondition
|
||||
}
|
||||
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
var err *dto.MidjourneyResponse
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeMidjourneyNotify:
|
||||
err = relay.RelayMidjourneyNotify(c)
|
||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||
case relayconstant.RelayModeSwapFace:
|
||||
err = relay.RelaySwapFace(c)
|
||||
default:
|
||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||
}
|
||||
//err = relayMidjourneySubmit(c, relayMode)
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
}
|
||||
c.JSON(429, gin.H{
|
||||
"error": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"type": "upstream_error",
|
||||
})
|
||||
statusCode := http.StatusBadRequest
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
statusCode = http.StatusTooManyRequests
|
||||
}
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"type": "upstream_error",
|
||||
"code": err.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
//if shouldDisableChannel(&err.OpenAIError) {
|
||||
// channelId := c.GetInt("channel_id")
|
||||
// channelName := c.GetString("channel_name")
|
||||
// disableChannel(channelId, channelName, err.Result)
|
||||
//};''''''''''''''''''''''''''''''''
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -111,6 +112,33 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
@@ -122,6 +150,7 @@ func EpayNotify(c *gin.Context) {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
@@ -135,11 +164,19 @@ func EpayNotify(c *gin.Context) {
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
|
||||
@@ -65,6 +65,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -683,7 +684,7 @@ func ManageUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.HardDelete(); err != nil {
|
||||
if err := user.Delete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
@@ -2,18 +2,17 @@ version: '3.4'
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
# build: .
|
||||
image: pengzhile/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./data/new-api:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
@@ -23,13 +22,22 @@ services:
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
db:
|
||||
image: mysql:8.2.0
|
||||
container_name: mysql
|
||||
restart: always
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||
environment:
|
||||
TZ: Asia/Shanghai # 设置时区
|
||||
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||
MYSQL_USER: newapi # 创建专用用户
|
||||
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||
MYSQL_DATABASE: new-api # 自动创建数据库
|
||||
45
dto/error.go
45
dto/error.go
@@ -8,6 +8,47 @@ type OpenAIError struct {
|
||||
}
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
OpenAIError
|
||||
StatusCode int `json:"status_code"`
|
||||
Error OpenAIError `json:"error"`
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
package dto
|
||||
|
||||
//type SimpleMjRequest struct {
|
||||
// Prompt string `json:"prompt"`
|
||||
// CustomId string `json:"customId"`
|
||||
// Action string `json:"action"`
|
||||
// Content string `json:"content"`
|
||||
//}
|
||||
|
||||
type SwapFaceRequest struct {
|
||||
SourceBase64 string `json:"sourceBase64"`
|
||||
TargetBase64 string `json:"targetBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
@@ -9,6 +23,7 @@ type MidjourneyRequest struct {
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
@@ -17,3 +32,64 @@ type MidjourneyResponse struct {
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type MidjourneyResponseWithStatusCode struct {
|
||||
StatusCode int `json:"statusCode"`
|
||||
Response MidjourneyResponse
|
||||
}
|
||||
|
||||
type MidjourneyDto struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
CustomId string `json:"customId"`
|
||||
BotType string `json:"botType"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
Buttons any `json:"buttons"`
|
||||
MaskBase64 string `json:"maskBase64"`
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
type ActionButton struct {
|
||||
CustomId any `json:"customId"`
|
||||
Emoji any `json:"emoji"`
|
||||
Label any `json:"label"`
|
||||
Type any `json:"type"`
|
||||
Style any `json:"style"`
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
FinalPrompt string `json:"finalPrompt"`
|
||||
FinalZhPrompt string `json:"finalZhPrompt"`
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type GeneralOpenAIRequest struct {
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
@@ -82,6 +83,14 @@ func (m Message) StringContent() string {
|
||||
return string(m.Content)
|
||||
}
|
||||
|
||||
func (m Message) IsStringContent() bool {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
var stringContent string
|
||||
@@ -130,9 +139,3 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
@@ -61,3 +61,9 @@ type CompletionsStreamResponse struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
14
go.mod
14
go.mod
@@ -4,13 +4,12 @@ module one-api
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/chai2010/webp v1.1.1
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
github.com/gin-contrib/static v0.0.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-playground/validator/v10 v10.16.0
|
||||
github.com/go-playground/validator/v10 v10.19.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.3.0
|
||||
@@ -19,7 +18,8 @@ require (
|
||||
github.com/samber/lo v1.38.1
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/image v0.15.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -32,7 +32,7 @@ require (
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -50,7 +50,7 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
@@ -64,9 +64,9 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
29
go.sum
29
go.sum
@@ -3,8 +3,6 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
|
||||
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
@@ -19,6 +17,8 @@ github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cn
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
||||
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
|
||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
||||
@@ -37,6 +37,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -47,10 +48,10 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
|
||||
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
|
||||
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
@@ -79,8 +80,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
||||
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
@@ -108,10 +107,10 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
@@ -176,15 +175,19 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -197,16 +200,14 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -15,6 +15,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
linuxDoEnable := session.Get("linuxdo_enable")
|
||||
if username == nil {
|
||||
// Check access token
|
||||
accessToken := c.Request.Header.Get("Authorization")
|
||||
@@ -33,6 +34,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -50,6 +52,14 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户 LINUX DO 信任等级不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if role.(int) < minRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -100,16 +110,25 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !linuxDoEnabled {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
@@ -125,17 +144,11 @@ func TokenAuth() func(c *gin.Context) {
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
consumeQuota := true
|
||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
||||
consumeQuota = false
|
||||
}
|
||||
c.Set("consume_quota", consumeQuota)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -23,32 +27,59 @@ func Distribute() func(c *gin.Context) {
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
shouldSelectChannel := true
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
|
||||
// Midjourney
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "midjourney"
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
@@ -87,60 +118,64 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
if tokenModelLimit != nil {
|
||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// token model limit is empty, all models are not allowed
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if nil != channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
// 如果错误,而且渠道为空,说明是没有可用渠道
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
if channel == nil {
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
ban := true
|
||||
// parse *int to bool
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||
@@ -15,3 +15,13 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.Abort()
|
||||
common.LogError(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": description,
|
||||
"type": "new_api_error",
|
||||
"code": code,
|
||||
})
|
||||
c.Abort()
|
||||
common.LogError(c.Request.Context(), description)
|
||||
}
|
||||
|
||||
@@ -147,7 +147,12 @@ func FixAbility() (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
var channels []Channel
|
||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -204,6 +204,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsLinuxDoEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if linuxDoEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set linuxdo enabled error: " + err.Error())
|
||||
}
|
||||
return linuxDoEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -60,6 +62,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
common.UsingMySQL = true
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -90,6 +93,9 @@ func InitDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
err = db.AutoMigrate(&Channel{})
|
||||
if err != nil {
|
||||
@@ -148,3 +154,33 @@ func CloseDB() error {
|
||||
err = sqlDB.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
lastPingTime time.Time
|
||||
pingMutex sync.Mutex
|
||||
)
|
||||
|
||||
func PingDB() error {
|
||||
pingMutex.Lock()
|
||||
defer pingMutex.Unlock()
|
||||
|
||||
if time.Since(lastPingTime) < time.Second*10 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
log.Printf("Error getting sql.DB from GORM: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = sqlDB.Ping()
|
||||
if err != nil {
|
||||
log.Printf("Error pinging DB: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
lastPingTime = time.Now()
|
||||
common.SysLog("Database pinged successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,21 +4,23 @@ type Midjourney struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
Quota int `json:"quota"`
|
||||
Buttons string `json:"buttons"`
|
||||
Properties string `json:"properties"`
|
||||
}
|
||||
|
||||
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||
|
||||
@@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -30,6 +31,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
|
||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
@@ -65,6 +67,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoClientId"] = ""
|
||||
common.OptionMap["LinuxDoClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
common.OptionMap["TelegramBotName"] = ""
|
||||
common.OptionMap["WeChatServerAddress"] = ""
|
||||
@@ -88,6 +93,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -155,6 +161,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
common.GitHubOAuthEnabled = boolValue
|
||||
case "LinuxDoOAuthEnabled":
|
||||
common.LinuxDoOAuthEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
common.WeChatAuthEnabled = boolValue
|
||||
case "TelegramOAuthEnabled":
|
||||
@@ -181,6 +189,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.DataExportEnabled = boolValue
|
||||
case "DefaultCollapseSidebar":
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -217,6 +227,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.GitHubClientId = value
|
||||
case "GitHubClientSecret":
|
||||
common.GitHubClientSecret = value
|
||||
case "LinuxDoClientId":
|
||||
common.LinuxDoClientId = value
|
||||
case "LinuxDoClientSecret":
|
||||
common.LinuxDoClientSecret = value
|
||||
case "LinuxDoMinLevel":
|
||||
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
|
||||
case "Footer":
|
||||
common.Footer = value
|
||||
case "SystemName":
|
||||
|
||||
@@ -8,16 +8,17 @@ import (
|
||||
)
|
||||
|
||||
type Redemption struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Quota int `json:"quota" gorm:"default:100"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Quota int `json:"quota" gorm:"default:100"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
|
||||
|
||||
@@ -10,19 +10,20 @@ import (
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
|
||||
@@ -21,6 +21,8 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
|
||||
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -272,6 +274,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByLinuxDoId() error {
|
||||
if user.LinuxDoId == "" {
|
||||
return errors.New("LINUX DO id 为空!")
|
||||
}
|
||||
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -311,6 +321,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
|
||||
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -356,6 +370,18 @@ func IsUserEnabled(userId int) (bool, error) {
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func IsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
package ali
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
//History []AliMessage `json:"history,omitempty"`
|
||||
Messages []AliMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Input AliInput `json:"input,omitempty"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -14,41 +14,35 @@ import (
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
const EnableSearchModelSuffix = "-internet"
|
||||
|
||||
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
//prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.StringContent()
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.StringContent(),
|
||||
Bot: string(request.Messages[i+1].Content),
|
||||
})
|
||||
i++
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
Content: message.StringContent(),
|
||||
Role: strings.ToLower(message.Role),
|
||||
})
|
||||
}
|
||||
enableSearch := false
|
||||
aliModel := request.Model
|
||||
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
|
||||
enableSearch = true
|
||||
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
//Prompt: prompt,
|
||||
Messages: messages,
|
||||
},
|
||||
Parameters: AliParameters{
|
||||
IncrementalOutput: request.Stream,
|
||||
Seed: uint64(request.Seed),
|
||||
EnableSearch: enableSearch,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
// TopK: 50,
|
||||
// //Seed: 0,
|
||||
// //EnableSearch: false,
|
||||
//},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
@@ -242,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
|
||||
@@ -24,21 +24,10 @@ var baiduTokenStore sync.Map
|
||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
@@ -184,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
@@ -220,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
|
||||
@@ -9,18 +9,32 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
||||
if a.RequestMode == RequestModeMessage {
|
||||
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
|
||||
} else {
|
||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
@@ -38,7 +52,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return requestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return requestOpenAI2ClaudeMessage(*request)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
@@ -47,11 +65,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = claudeStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
||||
} else {
|
||||
err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package claude
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
|
||||
"claude-instant-1.2",
|
||||
"claude-2",
|
||||
"claude-2.0",
|
||||
"claude-2.1",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -4,14 +4,36 @@ type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeMediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *ClaudeMessageSource `json:"source,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeMessageSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type ClaudeMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
@@ -22,8 +44,25 @@ type ClaudeError struct {
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Content []ClaudeMediaMessage `json:"content"`
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
Index int `json:"index"` // stream only
|
||||
Delta *ClaudeMediaMessage `json:"delta"` // stream only
|
||||
Message *ClaudeResponse `json:"message"` // stream only: message_start
|
||||
}
|
||||
|
||||
//type ClaudeResponseChoice struct {
|
||||
// Index int `json:"index"`
|
||||
// Type string `json:"type"`
|
||||
//}
|
||||
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
return "stop"
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
default:
|
||||
@@ -24,7 +26,7 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
@@ -44,7 +46,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
@@ -52,51 +56,154 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" {
|
||||
claudeRequest.System = message.StringContent()
|
||||
} else {
|
||||
claudeMessage := ClaudeMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
|
||||
for _, mediaMessage := range message.ParseContent() {
|
||||
claudeMediaMessage := ClaudeMediaMessage{
|
||||
Type: mediaMessage.Type,
|
||||
}
|
||||
if mediaMessage.Type == "text" {
|
||||
claudeMediaMessage.Text = mediaMessage.Text
|
||||
} else {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
claudeMediaMessage.Type = "image"
|
||||
claudeMediaMessage.Source = &ClaudeMessageSource{
|
||||
Type: "base64",
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
} else {
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claudeMediaMessage.Source.MediaType = "image/" + format
|
||||
claudeMediaMessage.Source.Data = base64String
|
||||
}
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
claudeMessage.Content = claudeMediaMessages
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
var claudeUsage *ClaudeUsage
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
claudeUsage = &claudeResponse.Message.Usage
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
choice.Index = claudeResponse.Index
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
claudeUsage = &claudeResponse.Usage
|
||||
}
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
return &response, claudeUsage
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []dto.OpenAITextResponseChoice{choice},
|
||||
}
|
||||
if reqMode == RequestModeCompletion {
|
||||
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
} else {
|
||||
fullTextResponse.Id = claudeResponse.Id
|
||||
for i, message := range claudeResponse.Content {
|
||||
content, _ := json.Marshal(message.Text)
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
}
|
||||
}
|
||||
|
||||
fullTextResponse.Choices = choices
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage dto.Usage
|
||||
responseText := ""
|
||||
createdTime := common.GetTimestamp()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
|
||||
return i + 4, data[0:i], nil
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
@@ -108,10 +215,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if !strings.HasPrefix(data, "event: completion") {
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
@@ -128,10 +235,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
responseText += claudeResponse.Completion
|
||||
response := streamResponseClaude2OpenAI(&claudeResponse)
|
||||
|
||||
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
if requestMode == RequestModeCompletion {
|
||||
responseText += claudeResponse.Completion
|
||||
responseId = response.Id
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
responseId = claudeResponse.Message.Id
|
||||
modelName = claudeResponse.Message.Model
|
||||
usage.PromptTokens = claudeUsage.InputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
responseText += claudeResponse.Delta.Text
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
usage.CompletionTokens = claudeUsage.OutputTokens
|
||||
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
//response.Id = responseId
|
||||
response.Id = responseId
|
||||
response.Created = createdTime
|
||||
response.Model = modelName
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
@@ -146,12 +274,19 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, responseText
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||
} else {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage = *service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -167,7 +302,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
@@ -176,12 +311,17 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens := service.CountTokenText(claudeResponse.Completion, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = promptTokens + completionTokens
|
||||
} else {
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
|
||||
@@ -20,6 +20,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
version := "v1"
|
||||
if info.ApiVersion != "" {
|
||||
version = info.ApiVersion
|
||||
}
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
|
||||
@@ -140,8 +140,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
},
|
||||
FinishReason: relaycommon.StopFinishReason,
|
||||
}
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
choice.Message.Content = content
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
@@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
|
||||
59
relay/channel/ollama/adaptor.go
Normal file
59
relay/channel/ollama/adaptor.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/chat", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
5
relay/channel/ollama/constants.go
Normal file
5
relay/channel/ollama/constants.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package ollama
|
||||
|
||||
var ModelList []string
|
||||
|
||||
var ChannelName = "ollama"
|
||||
18
relay/channel/ollama/dto.go
Normal file
18
relay/channel/ollama/dto.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package ollama
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options *OllamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaOptions struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
}
|
||||
32
relay/channel/ollama/relay-ollama.go
Normal file
32
relay/channel/ollama/relay-ollama.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package ollama
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
str, ok := request.Stop.(string)
|
||||
var Stop []string
|
||||
if ok {
|
||||
Stop = []string{str}
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
return &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Options: &OllamaOptions{
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -49,11 +49,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
|
||||
req.Header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
req.Header.Set("X-Title", "One API")
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
req.Header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
// req.Header.Set("X-Title", "One API")
|
||||
//}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
|
||||
@@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
|
||||
63
relay/channel/perplexity/adaptor.go
Normal file
63
relay/channel/perplexity/adaptor.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package perplexity
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.99
|
||||
}
|
||||
return requestOpenAI2Perplexity(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/perplexity/constants.go
Normal file
7
relay/channel/perplexity/constants.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package perplexity
|
||||
|
||||
var ModelList = []string{
|
||||
"sonar-small-chat", "sonar-small-online", "sonar-medium-chat", "sonar-medium-online", "mistral-7b-instruct", "mixtral-8x7b-instruct",
|
||||
}
|
||||
|
||||
var ChannelName = "perplexity"
|
||||
21
relay/channel/perplexity/relay-perplexity.go
Normal file
21
relay/channel/perplexity/relay-perplexity.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package perplexity
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
return &dto.GeneralOpenAIRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
}
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
|
||||
@@ -24,8 +24,9 @@ import (
|
||||
|
||||
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
if message.Role == "system" && shouldCovertSystemMessage {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
@@ -126,7 +127,7 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
@@ -156,7 +157,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
@@ -235,20 +236,44 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
return dataChan, stopChan, nil
|
||||
}
|
||||
|
||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
if apiVersion == "" {
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
}
|
||||
domain := "general"
|
||||
if apiVersion != "v1.1" {
|
||||
domain += strings.Split(apiVersion, ".")[0]
|
||||
func apiVersion2domain(apiVersion string) string {
|
||||
switch apiVersion {
|
||||
case "v1.1":
|
||||
return "general"
|
||||
case "v2.1":
|
||||
return "generalv2"
|
||||
case "v3.1":
|
||||
return "generalv3"
|
||||
case "v3.5":
|
||||
return "generalv3.5"
|
||||
}
|
||||
return "general" + apiVersion
|
||||
}
|
||||
|
||||
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
|
||||
apiVersion := getAPIVersion(c, modelName)
|
||||
domain := apiVersion2domain(apiVersion)
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
||||
}
|
||||
|
||||
func getAPIVersion(c *gin.Context, modelName string) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion != "" {
|
||||
return apiVersion
|
||||
}
|
||||
parts := strings.Split(modelName, "-")
|
||||
if len(parts) == 2 {
|
||||
apiVersion = parts[1]
|
||||
return apiVersion
|
||||
|
||||
}
|
||||
apiVersion = c.GetString("api_version")
|
||||
if apiVersion != "" {
|
||||
return apiVersion
|
||||
}
|
||||
apiVersion = "v1.1"
|
||||
common.SysLog("api_version not found, using default: " + apiVersion)
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
@@ -36,6 +36,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.99
|
||||
}
|
||||
return requestOpenAI2Zhipu(*request), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
|
||||
@@ -34,6 +34,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if request.TopP >= 1 {
|
||||
request.TopP = 0.99
|
||||
}
|
||||
return requestOpenAI2Zhipu(*request), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
|
||||
@@ -24,6 +24,7 @@ type RelayInfo struct {
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
}
|
||||
|
||||
@@ -52,12 +53,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAzureAPIVersion(c)
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -17,10 +17,10 @@ import (
|
||||
|
||||
var StopFinishReason = "stop"
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: dto.OpenAIError{
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.Open
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||
OpenAIErrorWithStatusCode.Error = textResponse.Error
|
||||
return
|
||||
}
|
||||
|
||||
@@ -66,12 +66,3 @@ func GetAPIVersion(c *gin.Context) string {
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
func GetAzureAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
@@ -16,6 +16,8 @@ const (
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipu_v4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -43,6 +45,10 @@ func ChannelType2APIType(channelType int) int {
|
||||
apiType = APITypeGemini
|
||||
case common.ChannelTypeZhipu_v4:
|
||||
apiType = APITypeZhipu_v4
|
||||
case common.ChannelTypeOllama:
|
||||
apiType = APITypeOllama
|
||||
case common.ChannelTypePerplexity:
|
||||
apiType = APITypePerplexity
|
||||
}
|
||||
return apiType
|
||||
}
|
||||
|
||||
@@ -17,10 +17,15 @@ const (
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskImageSeed
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
RelayModeMidjourneyAction
|
||||
RelayModeMidjourneyModal
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeSwapFace
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
@@ -48,3 +53,39 @@ func Path2RelayMode(path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayModeMidjourney(path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(path, "/mj/submit/action") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyAction
|
||||
} else if strings.HasPrefix(path, "/mj/submit/modal") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyModal
|
||||
} else if strings.HasPrefix(path, "/mj/submit/shorten") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeMidjourneyShorten
|
||||
} else if strings.HasPrefix(path, "/mj/insight-face/swap") {
|
||||
// midjourney plus
|
||||
relayMode = RelayModeSwapFace
|
||||
} else if strings.HasPrefix(path, "/mj/submit/imagine") {
|
||||
relayMode = RelayModeMidjourneyImagine
|
||||
} else if strings.HasPrefix(path, "/mj/submit/blend") {
|
||||
relayMode = RelayModeMidjourneyBlend
|
||||
} else if strings.HasPrefix(path, "/mj/submit/describe") {
|
||||
relayMode = RelayModeMidjourneyDescribe
|
||||
} else if strings.HasPrefix(path, "/mj/notify") {
|
||||
relayMode = RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(path, "/mj/submit/change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(path, "/mj/submit/simple-change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(path, "/fetch") {
|
||||
relayMode = RelayModeMidjourneyTaskFetch
|
||||
} else if strings.HasSuffix(path, "/image-seed") {
|
||||
relayMode = RelayModeMidjourneyTaskImageSeed
|
||||
} else if strings.HasSuffix(path, "/list-by-condition") {
|
||||
relayMode = RelayModeMidjourneyTaskFetchByCondition
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiVersion := relaycommon.GetAzureAPIVersion(c)
|
||||
apiVersion := relaycommon.GetAPIVersion(c)
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,16 +24,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
var imageRequest dto.ImageRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
@@ -136,7 +133,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
|
||||
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
if userQuota-quota < 0 {
|
||||
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
@@ -176,46 +173,42 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
var textResponse dto.ImageResponse
|
||||
defer func(ctx context.Context) {
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
if consumeQuota {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relayconstant "one-api/relay/constant"
|
||||
@@ -20,53 +21,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Midjourney struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
type MidjourneyWithoutStatus struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
}
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
@@ -108,7 +62,7 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||||
var midjRequest Midjourney
|
||||
var midjRequest dto.MidjourneyDto
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -147,7 +101,7 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
|
||||
func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
@@ -167,9 +121,164 @@ func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
if originTask.Buttons != "" {
|
||||
var buttons []dto.ActionButton
|
||||
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
|
||||
if err == nil {
|
||||
midjourneyTask.Buttons = buttons
|
||||
}
|
||||
}
|
||||
if originTask.Properties != "" {
|
||||
var properties dto.Properties
|
||||
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||||
if err == nil {
|
||||
midjourneyTask.Properties = &properties
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
var swapFaceRequest dto.SwapFaceRequest
|
||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelPrice := common.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
baseURL := c.GetString("base_url")
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||||
if err != nil {
|
||||
return &mjResp.Response
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
midjResponse := &mjResp.Response
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: constant.MjActionSwapFace,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: "InsightFace",
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: startTime,
|
||||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
||||
}
|
||||
c.Writer.WriteHeader(mjResp.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
taskId := c.Param("id")
|
||||
userId := c.GetInt("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
||||
}
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||
}
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
requestURL := c.Request.URL.String()
|
||||
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||||
if err != nil {
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||
respBody, err := json.Marshal(midjResponse)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
var err error
|
||||
@@ -184,7 +293,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||||
respBody, err = json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -203,16 +312,16 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
var tasks []Midjourney
|
||||
var tasks []dto.MidjourneyDto
|
||||
if len(condition.IDs) != 0 {
|
||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||
for _, originTask := range originTasks {
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||||
tasks = append(tasks, midjourneyTask)
|
||||
}
|
||||
}
|
||||
if tasks == nil {
|
||||
tasks = make([]Midjourney, 0)
|
||||
tasks = make([]dto.MidjourneyDto, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(tasks)
|
||||
if err != nil {
|
||||
@@ -235,170 +344,115 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// type 1 根据 mode 价格不同
|
||||
MJSubmitActionImagine = "IMAGINE"
|
||||
MJSubmitActionVariation = "VARIATION" //变换
|
||||
MJSubmitActionBlend = "BLEND" //混图
|
||||
|
||||
MJSubmitActionReroll = "REROLL" //重新生成
|
||||
// type 2 固定价格
|
||||
MJSubmitActionDescribe = "DESCRIBE"
|
||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
||||
)
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
//channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
consumeQuota := true
|
||||
var midjRequest dto.MidjourneyRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||
if mjErr != nil {
|
||||
return mjErr
|
||||
}
|
||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "prompt_is_required",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
||||
}
|
||||
midjRequest.Action = "IMAGINE"
|
||||
midjRequest.Action = constant.MjActionImagine
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = "DESCRIBE"
|
||||
midjRequest.Action = constant.MjActionDescribe
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||
midjRequest.Action = constant.MjActionShorten
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = "BLEND"
|
||||
midjRequest.Action = constant.MjActionBlend
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
mjId := ""
|
||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||
} else if midjRequest.Action == "" {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_is_required",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||
}
|
||||
params := convertSimpleChangeParams(midjRequest.Content)
|
||||
params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_parse_failed",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
||||
}
|
||||
mjId = params.ID
|
||||
mjId = params.TaskId
|
||||
midjRequest.Action = params.Action
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
||||
//if midjRequest.MaskBase64 == "" {
|
||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
||||
//}
|
||||
mjId = midjRequest.TaskId
|
||||
midjRequest.Action = constant.MjActionModal
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
} else if originTask.Action == "UPSCALE" {
|
||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "upscale_task_can_not_be_change",
|
||||
}
|
||||
} else if originTask.Status != "SUCCESS" {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_status_is_not_success",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, false)
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "channel_not_found",
|
||||
}
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||||
}
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
|
||||
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||||
// // plus
|
||||
//} else {
|
||||
// // 普通版渠道
|
||||
//
|
||||
//}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_model_mapping_failed",
|
||||
}
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
//baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
baseURL := c.GetString("base_url")
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
log.Printf("fullRequestURL: %s", fullRequestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(midjRequest)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "marshal_text_request_failed",
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(mjAction, true)
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
||||
defaultPrice, ok := common.DefaultModelPrice[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
@@ -423,53 +477,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "create_request_failed",
|
||||
}
|
||||
return &midjResponseWithStatus.Response
|
||||
}
|
||||
//req.Header.Set("ApiKey", c.Request.Header.Get("ApiKey"))
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
//mjToken := ""
|
||||
//if c.Request.Header.Get("ApiKey") != "" {
|
||||
// mjToken = strings.Split(c.Request.Header.Get("ApiKey"), " ")[1]
|
||||
//}
|
||||
//req.Header.Set("ApiKey", "Bearer midjourney-proxy")
|
||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||
// print request header
|
||||
log.Printf("request header: %s", req.Header)
|
||||
log.Printf("request body: %s", midjRequest.Prompt)
|
||||
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
var midjResponse dto.MidjourneyResponse
|
||||
midjResponse := &midjResponseWithStatus.Response
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
@@ -481,7 +496,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
@@ -489,41 +504,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
//if consumeQuota {
|
||||
//
|
||||
//}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "read_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
log.Printf("midjResponse: %v", midjResponse)
|
||||
if resp.StatusCode != 200 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||||
//1-提交成功
|
||||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||||
@@ -575,8 +555,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
}
|
||||
//修改返回值
|
||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
|
||||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
}
|
||||
|
||||
err = midjourneyTask.Insert()
|
||||
@@ -593,21 +575,22 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
responseBody = []byte(newBody)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
//for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
//}
|
||||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
_, err = io.Copy(c.Writer, bodyReader)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
err = bodyReader.Close()
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -622,32 +605,3 @@ type taskChangeParams struct {
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &taskChangeParams{}
|
||||
changeParams.ID = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -59,7 +60,6 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
||||
}
|
||||
}
|
||||
relayInfo.IsStream = textRequest.Stream
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
@@ -85,9 +85,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
// set upstream model name
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
@@ -148,10 +150,19 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return service.RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
||||
@@ -169,6 +180,8 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model), nil
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model), nil
|
||||
default:
|
||||
err = errors.New("unknown relay mode")
|
||||
promptTokens = 0
|
||||
@@ -216,6 +229,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
return preConsumedQuota, userQuota, nil
|
||||
}
|
||||
|
||||
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
@@ -39,6 +41,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &zhipu.Adaptor{}
|
||||
case constant.APITypeZhipu_v4:
|
||||
return &zhipu_4v.Adaptor{}
|
||||
case constant.APITypeOllama:
|
||||
return &ollama.Adaptor{}
|
||||
case constant.APITypePerplexity:
|
||||
return &perplexity.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||
{
|
||||
apiRouter.GET("/status", controller.GetStatus)
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
@@ -22,6 +23,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
|
||||
@@ -47,6 +47,9 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
|
||||
relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
|
||||
@@ -54,7 +57,9 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||
}
|
||||
//relayMjRouter.Use()
|
||||
}
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: code,
|
||||
Description: desc,
|
||||
}
|
||||
}
|
||||
|
||||
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
|
||||
return &dto.MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: *MidjourneyErrorWrapper(code, desc),
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
|
||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
@@ -23,7 +41,42 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
|
||||
Code: code,
|
||||
}
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: openAIError,
|
||||
StatusCode: statusCode,
|
||||
Error: openAIError,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: dto.OpenAIError{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if errResponse.Error.Message != "" {
|
||||
// OpenAI format error, so we override the default one
|
||||
errWithStatusCode.Error = errResponse.Error
|
||||
} else {
|
||||
errWithStatusCode.Error.Message = errResponse.ToMessage()
|
||||
}
|
||||
if errWithStatusCode.Error.Message == "" {
|
||||
errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
235
service/midjourney.go
Normal file
235
service/midjourney.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func CoverActionToModelName(mjAction string) string {
|
||||
modelName := "mj_" + strings.ToLower(mjAction)
|
||||
if mjAction == constant.MjActionSwapFace {
|
||||
modelName = "swap_face"
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
|
||||
action := ""
|
||||
if relayMode == relayconstant.RelayModeMidjourneyAction {
|
||||
// plus request
|
||||
err := CoverPlusActionToNormalAction(midjRequest)
|
||||
if err != nil {
|
||||
return "", err, false
|
||||
}
|
||||
action = midjRequest.Action
|
||||
} else {
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeMidjourneyImagine:
|
||||
action = constant.MjActionImagine
|
||||
case relayconstant.RelayModeMidjourneyDescribe:
|
||||
action = constant.MjActionDescribe
|
||||
case relayconstant.RelayModeMidjourneyBlend:
|
||||
action = constant.MjActionBlend
|
||||
case relayconstant.RelayModeMidjourneyShorten:
|
||||
action = constant.MjActionShorten
|
||||
case relayconstant.RelayModeMidjourneyChange:
|
||||
action = midjRequest.Action
|
||||
case relayconstant.RelayModeMidjourneyModal:
|
||||
action = constant.MjActionModal
|
||||
case relayconstant.RelayModeSwapFace:
|
||||
action = constant.MjActionSwapFace
|
||||
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
|
||||
}
|
||||
action = params.Action
|
||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
|
||||
return "", nil, true
|
||||
default:
|
||||
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
|
||||
}
|
||||
}
|
||||
modelName := CoverActionToModelName(action)
|
||||
return modelName, nil, true
|
||||
}
|
||||
|
||||
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
||||
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
||||
customId := midjRequest.CustomId
|
||||
if customId == "" {
|
||||
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
||||
}
|
||||
splits := strings.Split(customId, "::")
|
||||
var action string
|
||||
if splits[1] == "JOB" {
|
||||
action = splits[2]
|
||||
} else {
|
||||
action = splits[1]
|
||||
}
|
||||
|
||||
if action == "" {
|
||||
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
||||
}
|
||||
if strings.Contains(action, "upsample") {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = constant.MjActionUpscale
|
||||
} else if strings.Contains(action, "variation") {
|
||||
midjRequest.Index = 1
|
||||
if action == "variation" {
|
||||
index, err := strconv.Atoi(splits[3])
|
||||
if err != nil {
|
||||
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
||||
}
|
||||
midjRequest.Index = index
|
||||
midjRequest.Action = constant.MjActionVariation
|
||||
} else if action == "low_variation" {
|
||||
midjRequest.Action = constant.MjActionLowVariation
|
||||
} else if action == "high_variation" {
|
||||
midjRequest.Action = constant.MjActionHighVariation
|
||||
}
|
||||
} else if strings.Contains(action, "pan") {
|
||||
midjRequest.Action = constant.MjActionPan
|
||||
midjRequest.Index = 1
|
||||
} else if strings.Contains(action, "reroll") {
|
||||
midjRequest.Action = constant.MjActionReRoll
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Outpaint" {
|
||||
midjRequest.Action = constant.MjActionZoom
|
||||
midjRequest.Index = 1
|
||||
} else if action == "CustomZoom" {
|
||||
midjRequest.Action = constant.MjActionCustomZoom
|
||||
midjRequest.Index = 1
|
||||
} else if action == "Inpaint" {
|
||||
midjRequest.Action = constant.MjActionInPaint
|
||||
midjRequest.Index = 1
|
||||
} else {
|
||||
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &dto.MidjourneyRequest{}
|
||||
changeParams.TaskId = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
|
||||
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
||||
var nullBytes []byte
|
||||
//var requestBody io.Reader
|
||||
//requestBody = c.Request.Body
|
||||
// read request body to json, delete accountFilter and notifyHook
|
||||
var mapResult map[string]interface{}
|
||||
// if get request, no need to read request body
|
||||
if c.Request.Method != "GET" {
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
delete(mapResult, "accountFilter")
|
||||
if !constant.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// make new request with mapResult
|
||||
}
|
||||
reqBody, err := json.Marshal(mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||
defer cancel()
|
||||
resp, err := GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
common.SysError("do request failed: " + err.Error())
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
statusCode := resp.StatusCode
|
||||
//if statusCode != 200 {
|
||||
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
|
||||
//}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
var midjResponse dto.MidjourneyResponse
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
respStr := string(responseBody)
|
||||
log.Printf("responseBody: %s", respStr)
|
||||
if respStr == "" {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||
} else {
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
}
|
||||
//log.Printf("midjResponse: %v", midjResponse)
|
||||
//for k, v := range resp.Header {
|
||||
// c.Writer.Header().Set(k, v[0])
|
||||
//}
|
||||
return &dto.MidjourneyResponseWithStatusCode{
|
||||
StatusCode: statusCode,
|
||||
Response: midjResponse,
|
||||
}, responseBody, nil
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
4
web/.gitignore
vendored
4
web/.gitignore
vendored
@@ -21,6 +21,4 @@
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.idea
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.idea/
|
||||
@@ -49,8 +49,9 @@
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"prettier": "^2.7.1",
|
||||
"typescript": "4.4.2"
|
||||
"prettier": "2.8.8",
|
||||
"typescript": "4.4.2",
|
||||
"@babel/plugin-proposal-private-property-in-object": "^7.21.11"
|
||||
},
|
||||
"prettier": {
|
||||
"singleQuote": true,
|
||||
|
||||
379
web/src/App.js
379
web/src/App.js
@@ -8,12 +8,12 @@ import LoginForm from './components/LoginForm';
|
||||
import NotFound from './pages/NotFound';
|
||||
import Setting from './pages/Setting';
|
||||
import EditUser from './pages/User/EditUser';
|
||||
import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
|
||||
import { getLogo, getSystemName } from './helpers';
|
||||
import PasswordResetForm from './components/PasswordResetForm';
|
||||
import GitHubOAuth from './components/GitHubOAuth';
|
||||
import LinuxDoOAuth from "./components/LinuxDoOAuth";
|
||||
import PasswordResetConfirm from './components/PasswordResetConfirm';
|
||||
import { UserContext } from './context/User';
|
||||
import { StatusContext } from './context/Status';
|
||||
import Channel from './pages/Channel';
|
||||
import Token from './pages/Token';
|
||||
import EditChannel from './pages/Channel/EditChannel';
|
||||
@@ -21,12 +21,13 @@ import Redemption from './pages/Redemption';
|
||||
import TopUp from './pages/TopUp';
|
||||
import Log from './pages/Log';
|
||||
import Chat from './pages/Chat';
|
||||
import {Layout} from "@douyinfe/semi-ui";
|
||||
import Midjourney from "./pages/Midjourney";
|
||||
import Detail from "./pages/Detail";
|
||||
import { Layout } from '@douyinfe/semi-ui';
|
||||
import Midjourney from './pages/Midjourney';
|
||||
import Detail from './pages/Detail';
|
||||
|
||||
const Home = lazy(() => import('./pages/Home'));
|
||||
const About = lazy(() => import('./pages/About'));
|
||||
|
||||
function App() {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
// const [statusState, statusDispatch] = useContext(StatusContext);
|
||||
@@ -47,7 +48,7 @@ function App() {
|
||||
}
|
||||
let logo = getLogo();
|
||||
if (logo) {
|
||||
let linkElement = document.querySelector("link[rel~='icon']");
|
||||
let linkElement = document.querySelector('link[rel~=\'icon\']');
|
||||
if (linkElement) {
|
||||
linkElement.href = logo;
|
||||
}
|
||||
@@ -56,185 +57,193 @@ function App() {
|
||||
|
||||
return (
|
||||
<Layout>
|
||||
<Layout.Content>
|
||||
<Routes>
|
||||
<Route
|
||||
path='/'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Home />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Channel />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/edit/:id'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/add'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/token'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Token />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/redemption'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Redemption />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<User />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit/:id'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/reset'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetConfirm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/login'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LoginForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/register'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<RegisterForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/reset'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/github'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<GitHubOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/setting'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Setting />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/topup'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<TopUp />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/log'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Log />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/detail'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Detail />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/midjourney'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Midjourney />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/about'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<About />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/chat'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Chat />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route path='*' element={
|
||||
<NotFound />
|
||||
} />
|
||||
</Routes>
|
||||
</Layout.Content>
|
||||
<Layout.Content>
|
||||
<Routes>
|
||||
<Route
|
||||
path="/"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Home />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/channel"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Channel />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/channel/edit/:id"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/channel/add"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/token"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Token />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/redemption"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Redemption />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/user"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<User />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/user/edit/:id"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/user/edit"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/user/reset"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetConfirm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/login"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LoginForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/register"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<RegisterForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/reset"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/oauth/github"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<GitHubOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/oauth/linuxdo"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LinuxDoOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/setting"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Setting />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/topup"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<TopUp />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/log"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Log />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/detail"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Detail />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/midjourney"
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Midjourney />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/about"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<About />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/chat"
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Chat />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route path="*" element={
|
||||
<NotFound />
|
||||
} />
|
||||
</Routes>
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
|
||||
import { getFooterHTML, getSystemName } from '../helpers';
|
||||
import {Layout} from "@douyinfe/semi-ui";
|
||||
import { Layout } from '@douyinfe/semi-ui';
|
||||
|
||||
const Footer = () => {
|
||||
const systemName = getSystemName();
|
||||
@@ -29,30 +29,30 @@ const Footer = () => {
|
||||
|
||||
return (
|
||||
<Layout>
|
||||
<Layout.Content style={{textAlign: 'center'}}>
|
||||
<Layout.Content style={{ textAlign: 'center' }}>
|
||||
{footer ? (
|
||||
<div
|
||||
className='custom-footer'
|
||||
className="custom-footer"
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
) : (
|
||||
<div className='custom-footer'>
|
||||
<div className="custom-footer">
|
||||
<a
|
||||
href='https://github.com/Calcium-Ion/new-api'
|
||||
target='_blank'
|
||||
href="https://github.com/Calcium-Ion/new-api"
|
||||
target="_blank" rel="noreferrer"
|
||||
>
|
||||
New API {process.env.REACT_APP_VERSION}{' '}
|
||||
</a>
|
||||
由{' '}
|
||||
<a href='https://github.com/Calcium-Ion' target='_blank'>
|
||||
<a href="https://github.com/Calcium-Ion" target="_blank" rel="noreferrer">
|
||||
Calcium-Ion
|
||||
</a>{' '}
|
||||
开发,基于{' '}
|
||||
<a href='https://github.com/songquanpeng/one-api' target='_blank'>
|
||||
<a href="https://github.com/songquanpeng/one-api" target="_blank" rel="noreferrer">
|
||||
One API v0.5.4
|
||||
</a>{' '}
|
||||
,本项目根据{' '}
|
||||
<a href='https://opensource.org/licenses/mit-license.php'>
|
||||
<a href="https://opensource.org/licenses/mit-license.php">
|
||||
MIT 许可证
|
||||
</a>{' '}
|
||||
授权
|
||||
|
||||
@@ -14,7 +14,8 @@ const GitHubOAuth = () => {
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (message === 'bind') {
|
||||
@@ -41,6 +42,14 @@ const GitHubOAuth = () => {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
@@ -49,7 +58,7 @@ const GitHubOAuth = () => {
|
||||
return (
|
||||
<Segment style={{ minHeight: '300px' }}>
|
||||
<Dimmer active inverted>
|
||||
<Loader size='large'>{prompt}</Loader>
|
||||
<Loader size="large">{prompt}</Loader>
|
||||
</Dimmer>
|
||||
</Segment>
|
||||
);
|
||||
|
||||
@@ -1,165 +1,161 @@
|
||||
import React, {useContext, useEffect, useRef, useState} from 'react';
|
||||
import {Link, useNavigate} from 'react-router-dom';
|
||||
import {UserContext} from '../context/User';
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Link, useNavigate } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
|
||||
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
|
||||
import { API, getLogo, getSystemName, showSuccess } from '../helpers';
|
||||
import '../index.css';
|
||||
|
||||
import fireworks from 'react-fireworks';
|
||||
|
||||
import {
|
||||
IconKey,
|
||||
IconUser,
|
||||
IconHelpCircle
|
||||
} from '@douyinfe/semi-icons';
|
||||
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
|
||||
import {stringToColor} from "../helpers/render";
|
||||
import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons';
|
||||
import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
|
||||
import { stringToColor } from '../helpers/render';
|
||||
|
||||
// HeaderBar Buttons
|
||||
let headerButtons = [
|
||||
{
|
||||
text: '关于',
|
||||
itemKey: 'about',
|
||||
to: '/about',
|
||||
icon: <IconHelpCircle/>
|
||||
},
|
||||
{
|
||||
text: '关于',
|
||||
itemKey: 'about',
|
||||
to: '/about',
|
||||
icon: <IconHelpCircle />
|
||||
}
|
||||
];
|
||||
|
||||
if (localStorage.getItem('chat_link')) {
|
||||
headerButtons.splice(1, 0, {
|
||||
name: '聊天',
|
||||
to: '/chat',
|
||||
icon: 'comments'
|
||||
});
|
||||
headerButtons.splice(1, 0, {
|
||||
name: '聊天',
|
||||
to: '/chat',
|
||||
icon: 'comments'
|
||||
});
|
||||
}
|
||||
|
||||
const HeaderBar = () => {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
let navigate = useNavigate();
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
let navigate = useNavigate();
|
||||
|
||||
const [showSidebar, setShowSidebar] = useState(false);
|
||||
const [dark, setDark] = useState(false);
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
var themeMode = localStorage.getItem('theme-mode');
|
||||
const currentDate = new Date();
|
||||
// enable fireworks on new year(1.1 and 2.9-2.24)
|
||||
const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24);
|
||||
const [showSidebar, setShowSidebar] = useState(false);
|
||||
const [dark, setDark] = useState(false);
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
var themeMode = localStorage.getItem('theme-mode');
|
||||
const currentDate = new Date();
|
||||
// enable fireworks on new year(1.1 and 2.9-2.24)
|
||||
const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1) || (currentDate.getMonth() === 1 && currentDate.getDate() >= 9 && currentDate.getDate() <= 24);
|
||||
|
||||
async function logout() {
|
||||
setShowSidebar(false);
|
||||
await API.get('/api/user/logout');
|
||||
showSuccess('注销成功!');
|
||||
userDispatch({type: 'logout'});
|
||||
localStorage.removeItem('user');
|
||||
navigate('/login');
|
||||
async function logout() {
|
||||
setShowSidebar(false);
|
||||
await API.get('/api/user/logout');
|
||||
showSuccess('注销成功!');
|
||||
userDispatch({ type: 'logout' });
|
||||
localStorage.removeItem('user');
|
||||
navigate('/login');
|
||||
}
|
||||
|
||||
const handleNewYearClick = () => {
|
||||
fireworks.init('root', {});
|
||||
fireworks.start();
|
||||
setTimeout(() => {
|
||||
fireworks.stop();
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 10000);
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (themeMode === 'dark') {
|
||||
switchMode(true);
|
||||
}
|
||||
if (isNewYear) {
|
||||
console.log('Happy New Year!');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleNewYearClick = () => {
|
||||
fireworks.init("root",{});
|
||||
fireworks.start();
|
||||
setTimeout(() => {
|
||||
fireworks.stop();
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 10000);
|
||||
}, 3000);
|
||||
};
|
||||
const switchMode = (model) => {
|
||||
const body = document.body;
|
||||
if (!model) {
|
||||
body.removeAttribute('theme-mode');
|
||||
localStorage.setItem('theme-mode', 'light');
|
||||
} else {
|
||||
body.setAttribute('theme-mode', 'dark');
|
||||
localStorage.setItem('theme-mode', 'dark');
|
||||
}
|
||||
setDark(model);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<Layout>
|
||||
<div style={{ width: '100%' }}>
|
||||
<Nav
|
||||
mode={'horizontal'}
|
||||
// bodyStyle={{ height: 100 }}
|
||||
renderWrapper={({ itemElement, isSubNav, isInSubNav, props }) => {
|
||||
const routerMap = {
|
||||
about: '/about',
|
||||
login: '/login',
|
||||
register: '/register'
|
||||
};
|
||||
return (
|
||||
<Link
|
||||
style={{ textDecoration: 'none' }}
|
||||
to={routerMap[props.itemKey]}
|
||||
>
|
||||
{itemElement}
|
||||
</Link>
|
||||
);
|
||||
}}
|
||||
selectedKeys={[]}
|
||||
// items={headerButtons}
|
||||
onSelect={key => {
|
||||
|
||||
useEffect(() => {
|
||||
if (themeMode === 'dark') {
|
||||
switchMode(true);
|
||||
}
|
||||
if (isNewYear) {
|
||||
console.log('Happy New Year!');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const switchMode = (model) => {
|
||||
const body = document.body;
|
||||
if (!model) {
|
||||
body.removeAttribute('theme-mode');
|
||||
localStorage.setItem('theme-mode', 'light');
|
||||
} else {
|
||||
body.setAttribute('theme-mode', 'dark');
|
||||
localStorage.setItem('theme-mode', 'dark');
|
||||
}
|
||||
setDark(model);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<Layout>
|
||||
<div style={{width: '100%'}}>
|
||||
<Nav
|
||||
mode={'horizontal'}
|
||||
// bodyStyle={{ height: 100 }}
|
||||
renderWrapper={({itemElement, isSubNav, isInSubNav, props}) => {
|
||||
const routerMap = {
|
||||
about: "/about",
|
||||
login: "/login",
|
||||
register: "/register",
|
||||
};
|
||||
return (
|
||||
<Link
|
||||
style={{textDecoration: "none"}}
|
||||
to={routerMap[props.itemKey]}
|
||||
>
|
||||
{itemElement}
|
||||
</Link>
|
||||
);
|
||||
}}
|
||||
selectedKeys={[]}
|
||||
// items={headerButtons}
|
||||
onSelect={key => {
|
||||
|
||||
}}
|
||||
footer={
|
||||
<>
|
||||
{isNewYear &&
|
||||
// happy new year
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
|
||||
</Dropdown>
|
||||
}
|
||||
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
|
||||
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
|
||||
{userState.user ?
|
||||
<>
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={logout}>退出</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Avatar size="small" color={stringToColor(userState.user.username)} style={{ margin: 4 }}>
|
||||
{userState.user.username[0]}
|
||||
</Avatar>
|
||||
<span>{userState.user.username}</span>
|
||||
</Dropdown>
|
||||
</>
|
||||
:
|
||||
<>
|
||||
<Nav.Item itemKey={'login'} text={'登录'} icon={<IconKey />} />
|
||||
<Nav.Item itemKey={'register'} text={'注册'} icon={<IconUser />} />
|
||||
</>
|
||||
}
|
||||
</>
|
||||
}
|
||||
}}
|
||||
footer={
|
||||
<>
|
||||
{isNewYear &&
|
||||
// happy new year
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Nav.Item itemKey={'new-year'} text={'🏮'} />
|
||||
</Dropdown>
|
||||
}
|
||||
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
|
||||
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
|
||||
{userState.user ?
|
||||
<>
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={logout}>退出</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
</Nav>
|
||||
</div>
|
||||
</Layout>
|
||||
</>
|
||||
);
|
||||
<Avatar size="small" color={stringToColor(userState.user.username)} style={{ margin: 4 }}>
|
||||
{userState.user.username[0]}
|
||||
</Avatar>
|
||||
<span>{userState.user.username}</span>
|
||||
</Dropdown>
|
||||
</>
|
||||
:
|
||||
<>
|
||||
<Nav.Item itemKey={'login'} text={'登录'} icon={<IconKey />} />
|
||||
<Nav.Item itemKey={'register'} text={'注册'} icon={<IconUser />} />
|
||||
</>
|
||||
}
|
||||
</>
|
||||
}
|
||||
>
|
||||
</Nav>
|
||||
</div>
|
||||
</Layout>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default HeaderBar;
|
||||
|
||||
21
web/src/components/LinuxDoIcon.js
Normal file
21
web/src/components/LinuxDoIcon.js
Normal file
@@ -0,0 +1,21 @@
|
||||
import React from 'react';
|
||||
import {Icon} from '@douyinfe/semi-ui';
|
||||
|
||||
const LinuxDoIcon = (props) => {
|
||||
function CustomIcon() {
|
||||
return <svg className='icon' viewBox='0 0 24 24' version='1.1'
|
||||
xmlns='http://www.w3.org/2000/svg' width='16' height='16' {...props}>
|
||||
<path
|
||||
d="M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z"
|
||||
fill="currentColor"/>
|
||||
</svg>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Icon svg={<CustomIcon/>}/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LinuxDoIcon;
|
||||
67
web/src/components/LinuxDoOAuth.js
Normal file
67
web/src/components/LinuxDoOAuth.js
Normal file
@@ -0,0 +1,67 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import { UserContext } from '../context/User';
|
||||
|
||||
const LinuxDoOAuth = () => {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [prompt, setPrompt] = useState('处理中...');
|
||||
const [processing, setProcessing] = useState(true);
|
||||
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/setting');
|
||||
} else {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
if (count === 0) {
|
||||
setPrompt(`操作失败,重定向至登录界面中...`);
|
||||
navigate('/setting'); // in case this is failed to bind GitHub
|
||||
return;
|
||||
}
|
||||
count++;
|
||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
||||
await sendCode(code, state, count);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Segment style={{ minHeight: '300px' }}>
|
||||
<Dimmer active inverted>
|
||||
<Loader size='large'>{prompt}</Loader>
|
||||
</Dimmer>
|
||||
</Segment>
|
||||
);
|
||||
};
|
||||
|
||||
export default LinuxDoOAuth;
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from 'react';
|
||||
import { Segment, Dimmer, Loader } from 'semantic-ui-react';
|
||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
|
||||
|
||||
const Loading = ({ prompt: name = 'page' }) => {
|
||||
return (
|
||||
|
||||
@@ -1,259 +1,265 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { API, getLogo, isMobile, showError, showInfo, showSuccess, showWarning } from '../helpers';
|
||||
import { onGitHubOAuthClicked } from './utils';
|
||||
import Turnstile from "react-turnstile";
|
||||
import { Layout, Card, Image, Form, Button, Divider, Modal } from "@douyinfe/semi-ui";
|
||||
import Title from "@douyinfe/semi-ui/lib/es/typography/title";
|
||||
import Text from "@douyinfe/semi-ui/lib/es/typography/text";
|
||||
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
||||
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { Button, Card, Divider, Form, Icon, Layout, Modal } from '@douyinfe/semi-ui';
|
||||
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
|
||||
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
|
||||
import { IconGithubLogo } from '@douyinfe/semi-icons';
|
||||
import LinuxDoIcon from './LinuxDoIcon';
|
||||
import WeChatIcon from './WeChatIcon';
|
||||
|
||||
const LoginForm = () => {
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
password: '',
|
||||
wechat_verification_code: ''
|
||||
});
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [submitted, setSubmitted] = useState(false);
|
||||
const { username, password } = inputs;
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
||||
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
||||
const [turnstileToken, setTurnstileToken] = useState('');
|
||||
let navigate = useNavigate();
|
||||
const [status, setStatus] = useState({});
|
||||
const logo = getLogo();
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
password: '',
|
||||
wechat_verification_code: ''
|
||||
});
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [submitted, setSubmitted] = useState(false);
|
||||
const { username, password } = inputs;
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
||||
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
||||
const [turnstileToken, setTurnstileToken] = useState('');
|
||||
let navigate = useNavigate();
|
||||
const [status, setStatus] = useState({});
|
||||
const logo = getLogo();
|
||||
|
||||
useEffect(() => {
|
||||
if (searchParams.get('expired')) {
|
||||
showError('未登录或登录已过期,请重新登录!');
|
||||
}
|
||||
let status = localStorage.getItem('status');
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
setStatus(status);
|
||||
if (status.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(status.turnstile_site_key);
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
|
||||
|
||||
const onWeChatLoginClicked = () => {
|
||||
setShowWeChatLoginModal(true);
|
||||
};
|
||||
|
||||
const onSubmitWeChatVerificationCode = async () => {
|
||||
if (turnstileEnabled && turnstileToken === '') {
|
||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||
return;
|
||||
}
|
||||
const res = await API.get(
|
||||
`/api/oauth/wechat?code=${inputs.wechat_verification_code}`
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
navigate('/');
|
||||
showSuccess('登录成功!');
|
||||
setShowWeChatLoginModal(false);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
function handleChange(name, value) {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
useEffect(() => {
|
||||
if (searchParams.get('expired')) {
|
||||
showError('未登录或登录已过期,请重新登录!');
|
||||
}
|
||||
|
||||
async function handleSubmit(e) {
|
||||
if (turnstileEnabled && turnstileToken === '') {
|
||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||
return;
|
||||
}
|
||||
setSubmitted(true);
|
||||
if (username && password) {
|
||||
const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, {
|
||||
username,
|
||||
password
|
||||
});
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
if (username === 'root' && password === '123456') {
|
||||
Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true });
|
||||
}
|
||||
navigate('/token');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} else {
|
||||
showError('请输入用户名和密码!');
|
||||
}
|
||||
let status = localStorage.getItem('status');
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
setStatus(status);
|
||||
if (status.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(status.turnstile_site_key);
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
// 添加Telegram登录处理函数
|
||||
const onTelegramLoginClicked = async (response) => {
|
||||
const fields = ["id", "first_name", "last_name", "username", "photo_url", "auth_date", "hash", "lang"];
|
||||
const params = {};
|
||||
fields.forEach((field) => {
|
||||
if (response[field]) {
|
||||
params[field] = response[field];
|
||||
}
|
||||
});
|
||||
const res = await API.get(`/api/oauth/telegram/login`, { params });
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Layout>
|
||||
<Layout.Header>
|
||||
</Layout.Header>
|
||||
<Layout.Content>
|
||||
<div style={{ justifyContent: 'center', display: "flex", marginTop: 120 }}>
|
||||
<div style={{ width: 500 }}>
|
||||
<Card>
|
||||
<Title heading={2} style={{ textAlign: 'center' }}>
|
||||
用户登录
|
||||
</Title>
|
||||
<Form>
|
||||
<Form.Input
|
||||
field={'username'}
|
||||
label={'用户名'}
|
||||
placeholder='用户名'
|
||||
name='username'
|
||||
onChange={(value) => handleChange('username', value)}
|
||||
/>
|
||||
<Form.Input
|
||||
field={'password'}
|
||||
label={'密码'}
|
||||
placeholder='密码'
|
||||
name='password'
|
||||
type='password'
|
||||
onChange={(value) => handleChange('password', value)}
|
||||
/>
|
||||
const onWeChatLoginClicked = () => {
|
||||
setShowWeChatLoginModal(true);
|
||||
};
|
||||
|
||||
<Button theme='solid' style={{ width: '100%' }} type={'primary'} size='large'
|
||||
htmlType={'submit'} onClick={handleSubmit}>
|
||||
登录
|
||||
</Button>
|
||||
</Form>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: 20 }}>
|
||||
<Text>
|
||||
没有账号请先 <Link to='/register'>注册账号</Link>
|
||||
</Text>
|
||||
<Text>
|
||||
忘记密码 <Link to='/reset'>点击重置</Link>
|
||||
</Text>
|
||||
</div>
|
||||
{status.github_oauth || status.wechat_login || status.telegram_oauth ? (
|
||||
<>
|
||||
<Divider margin='12px' align='center'>
|
||||
第三方登录
|
||||
</Divider>
|
||||
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
|
||||
{status.github_oauth ? (
|
||||
<Button
|
||||
type='primary'
|
||||
icon={<IconGithubLogo />}
|
||||
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{/*{status.wechat_login ? (*/}
|
||||
{/* <Button*/}
|
||||
{/* circular*/}
|
||||
{/* color='green'*/}
|
||||
{/* icon='wechat'*/}
|
||||
{/* onClick={onWeChatLoginClicked}*/}
|
||||
{/* />*/}
|
||||
{/*) : (*/}
|
||||
{/* <></>*/}
|
||||
{/*)}*/}
|
||||
|
||||
{status.telegram_oauth ? (
|
||||
<TelegramLoginButton dataOnauth={onTelegramLoginClicked} botName={status.telegram_bot_name} />
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{/*<Modal*/}
|
||||
{/* onClose={() => setShowWeChatLoginModal(false)}*/}
|
||||
{/* onOpen={() => setShowWeChatLoginModal(true)}*/}
|
||||
{/* open={showWeChatLoginModal}*/}
|
||||
{/* size={'mini'}*/}
|
||||
{/*>*/}
|
||||
{/* <Modal.Content>*/}
|
||||
{/* <Modal.Description>*/}
|
||||
{/* <Image src={status.wechat_qrcode} fluid/>*/}
|
||||
{/* <div style={{textAlign: 'center'}}>*/}
|
||||
{/* <p>*/}
|
||||
{/* 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)*/}
|
||||
{/* </p>*/}
|
||||
{/* </div>*/}
|
||||
{/* <Form size='large'>*/}
|
||||
{/* <Form.Input*/}
|
||||
{/* field={'wechat_verification_code'}*/}
|
||||
{/* placeholder='验证码'*/}
|
||||
{/* name='wechat_verification_code'*/}
|
||||
{/* value={inputs.wechat_verification_code}*/}
|
||||
{/* onChange={handleChange}*/}
|
||||
{/* />*/}
|
||||
{/* <Button*/}
|
||||
{/* color=''*/}
|
||||
{/* fluid*/}
|
||||
{/* size='large'*/}
|
||||
{/* onClick={onSubmitWeChatVerificationCode}*/}
|
||||
{/* >*/}
|
||||
{/* 登录*/}
|
||||
{/* </Button>*/}
|
||||
{/* </Form>*/}
|
||||
{/* </Modal.Description>*/}
|
||||
{/* </Modal.Content>*/}
|
||||
{/*</Modal>*/}
|
||||
</Card>
|
||||
{turnstileEnabled ? (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
|
||||
<Turnstile
|
||||
sitekey={turnstileSiteKey}
|
||||
onVerify={(token) => {
|
||||
setTurnstileToken(token);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
</div>
|
||||
const onSubmitWeChatVerificationCode = async () => {
|
||||
if (turnstileEnabled && turnstileToken === '') {
|
||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||
return;
|
||||
}
|
||||
const res = await API.get(
|
||||
`/api/oauth/wechat?code=${inputs.wechat_verification_code}`
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
navigate('/');
|
||||
showSuccess('登录成功!');
|
||||
setShowWeChatLoginModal(false);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
function handleChange(name, value) {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
}
|
||||
|
||||
async function handleSubmit(e) {
|
||||
if (turnstileEnabled && turnstileToken === '') {
|
||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||
return;
|
||||
}
|
||||
setSubmitted(true);
|
||||
if (username && password) {
|
||||
const res = await API.post(`/api/user/login?turnstile=${turnstileToken}`, {
|
||||
username,
|
||||
password
|
||||
});
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
if (username === 'root' && password === '123456') {
|
||||
Modal.error({ title: '您正在使用默认密码!', content: '请立刻修改默认密码!', centered: true });
|
||||
}
|
||||
navigate('/token');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} else {
|
||||
showError('请输入用户名和密码!');
|
||||
}
|
||||
}
|
||||
|
||||
// 添加Telegram登录处理函数
|
||||
const onTelegramLoginClicked = async (response) => {
|
||||
const fields = ['id', 'first_name', 'last_name', 'username', 'photo_url', 'auth_date', 'hash', 'lang'];
|
||||
const params = {};
|
||||
fields.forEach((field) => {
|
||||
if (response[field]) {
|
||||
params[field] = response[field];
|
||||
}
|
||||
});
|
||||
const res = await API.get(`/api/oauth/telegram/login`, { params });
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Layout>
|
||||
<Layout.Header>
|
||||
</Layout.Header>
|
||||
<Layout.Content>
|
||||
<div style={{ justifyContent: 'center', display: 'flex', marginTop: 120 }}>
|
||||
<div style={{ width: 500 }}>
|
||||
<Card>
|
||||
<Title heading={2} style={{ textAlign: 'center' }}>
|
||||
用户登录
|
||||
</Title>
|
||||
<Form>
|
||||
<Form.Input
|
||||
field={'username'}
|
||||
label={'用户名'}
|
||||
placeholder="用户名"
|
||||
name="username"
|
||||
onChange={(value) => handleChange('username', value)}
|
||||
/>
|
||||
<Form.Input
|
||||
field={'password'}
|
||||
label={'密码'}
|
||||
placeholder="密码"
|
||||
name="password"
|
||||
type="password"
|
||||
onChange={(value) => handleChange('password', value)}
|
||||
/>
|
||||
|
||||
<Button theme="solid" style={{ width: '100%' }} type={'primary'} size="large"
|
||||
htmlType={'submit'} onClick={handleSubmit}>
|
||||
登录
|
||||
</Button>
|
||||
</Form>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', marginTop: 20 }}>
|
||||
<Text>
|
||||
没有账号请先 <Link to="/register">注册账号</Link>
|
||||
</Text>
|
||||
<Text>
|
||||
忘记密码 <Link to="/reset">点击重置</Link>
|
||||
</Text>
|
||||
</div>
|
||||
{status.github_oauth || status.linuxdo_oauth || status.wechat_login || status.telegram_oauth ? (
|
||||
<>
|
||||
<Divider margin="12px" align="center">
|
||||
第三方登录
|
||||
</Divider>
|
||||
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
|
||||
{status.github_oauth ? (
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<IconGithubLogo />}
|
||||
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.linuxdo_oauth ? (
|
||||
<Button
|
||||
type="primary"
|
||||
icon={<LinuxDoIcon />}
|
||||
style={{color: '#000'}}
|
||||
onClick={() => onLinuxDoOAuthClicked(status.linuxdo_client_id)}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.wechat_login ? (
|
||||
<Button
|
||||
type="primary"
|
||||
style={{ color: 'rgba(var(--semi-green-5), 1)' }}
|
||||
icon={<Icon svg={<WeChatIcon />} />}
|
||||
onClick={onWeChatLoginClicked}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
||||
{status.telegram_oauth ? (
|
||||
<TelegramLoginButton dataOnauth={onTelegramLoginClicked} botName={status.telegram_bot_name} />
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
<Modal
|
||||
title="微信扫码登录"
|
||||
visible={showWeChatLoginModal}
|
||||
maskClosable={true}
|
||||
onOk={onSubmitWeChatVerificationCode}
|
||||
onCancel={() => setShowWeChatLoginModal(false)}
|
||||
okText={'登录'}
|
||||
size={'small'}
|
||||
centered={true}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItem: 'center', flexDirection: 'column' }}>
|
||||
<img src={status.wechat_qrcode} />
|
||||
</div>
|
||||
<div style={{ textAlign: 'center' }}>
|
||||
<p>
|
||||
微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)
|
||||
</p>
|
||||
</div>
|
||||
<Form size="large">
|
||||
<Form.Input
|
||||
field={'wechat_verification_code'}
|
||||
placeholder="验证码"
|
||||
label={'验证码'}
|
||||
value={inputs.wechat_verification_code}
|
||||
onChange={(value) => handleChange('wechat_verification_code', value)}
|
||||
/>
|
||||
</Form>
|
||||
</Modal>
|
||||
</Card>
|
||||
{turnstileEnabled ? (
|
||||
<div style={{ display: 'flex', justifyContent: 'center', marginTop: 20 }}>
|
||||
<Turnstile
|
||||
sitekey={turnstileSiteKey}
|
||||
onVerify={(token) => {
|
||||
setTurnstileToken(token);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LoginForm;
|
||||
|
||||
@@ -1,501 +1,399 @@
|
||||
import React, {useEffect, useState} from 'react';
|
||||
import {Label} from 'semantic-ui-react';
|
||||
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
|
||||
|
||||
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal, Spin, Space} from '@douyinfe/semi-ui';
|
||||
import {ITEMS_PER_PAGE} from '../constants';
|
||||
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
|
||||
import {
|
||||
IconAt,
|
||||
IconHistogram,
|
||||
IconGift,
|
||||
IconKey,
|
||||
IconUser,
|
||||
IconLayers,
|
||||
IconSetting,
|
||||
IconCreditCard,
|
||||
IconSemiLogo,
|
||||
IconHome,
|
||||
IconMore
|
||||
} from '@douyinfe/semi-icons';
|
||||
import Paragraph from "@douyinfe/semi-ui/lib/es/typography/paragraph";
|
||||
import { Avatar, Button, Form, Layout, Modal, Select, Space, Spin, Table, Tag } from '@douyinfe/semi-ui';
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
import { renderNumber, renderQuota, stringToColor } from '../helpers/render';
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
const {Header} = Layout;
|
||||
function renderTimestamp(timestamp) {
|
||||
return (
|
||||
<>
|
||||
{timestamp2string(timestamp)}
|
||||
</>
|
||||
);
|
||||
return (<>
|
||||
{timestamp2string(timestamp)}
|
||||
</>);
|
||||
}
|
||||
|
||||
const MODE_OPTIONS = [
|
||||
{key: 'all', text: '全部用户', value: 'all'},
|
||||
{key: 'self', text: '当前用户', value: 'self'}
|
||||
];
|
||||
const MODE_OPTIONS = [{ key: 'all', text: '全部用户', value: 'all' }, { key: 'self', text: '当前用户', value: 'self' }];
|
||||
|
||||
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
|
||||
'light-blue', 'lime', 'orange', 'pink',
|
||||
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||
]
|
||||
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', 'light-blue', 'lime', 'orange', 'pink', 'purple', 'red', 'teal', 'violet', 'yellow'];
|
||||
|
||||
function renderType(type) {
|
||||
switch (type) {
|
||||
case 1:
|
||||
return <Tag color='cyan' size='large'> 充值 </Tag>;
|
||||
case 2:
|
||||
return <Tag color='lime' size='large'> 消费 </Tag>;
|
||||
case 3:
|
||||
return <Tag color='orange' size='large'> 管理 </Tag>;
|
||||
case 4:
|
||||
return <Tag color='purple' size='large'> 系统 </Tag>;
|
||||
default:
|
||||
return <Tag color='black' size='large'> 未知 </Tag>;
|
||||
}
|
||||
switch (type) {
|
||||
case 1:
|
||||
return <Tag color="cyan" size="large"> 充值 </Tag>;
|
||||
case 2:
|
||||
return <Tag color="lime" size="large"> 消费 </Tag>;
|
||||
case 3:
|
||||
return <Tag color="orange" size="large"> 管理 </Tag>;
|
||||
case 4:
|
||||
return <Tag color="purple" size="large"> 系统 </Tag>;
|
||||
default:
|
||||
return <Tag color="black" size="large"> 未知 </Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
function renderIsStream(bool) {
|
||||
if (bool) {
|
||||
return <Tag color='blue' size='large'>流</Tag>;
|
||||
} else {
|
||||
return <Tag color='purple' size='large'>非流</Tag>;
|
||||
}
|
||||
if (bool) {
|
||||
return <Tag color="blue" size="large">流</Tag>;
|
||||
} else {
|
||||
return <Tag color="purple" size="large">非流</Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
function renderUseTime(type) {
|
||||
const time = parseInt(type);
|
||||
if (time < 101) {
|
||||
return <Tag color='green' size='large'> {time} s </Tag>;
|
||||
} else if (time < 300) {
|
||||
return <Tag color='orange' size='large'> {time} s </Tag>;
|
||||
} else {
|
||||
return <Tag color='red' size='large'> {time} s </Tag>;
|
||||
}
|
||||
const time = parseInt(type);
|
||||
if (time < 101) {
|
||||
return <Tag color="green" size="large"> {time} s </Tag>;
|
||||
} else if (time < 300) {
|
||||
return <Tag color="orange" size="large"> {time} s </Tag>;
|
||||
} else {
|
||||
return <Tag color="red" size="large"> {time} s </Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
const LogsTable = () => {
|
||||
const columns = [
|
||||
{
|
||||
title: '时间',
|
||||
dataIndex: 'timestamp2string',
|
||||
},
|
||||
{
|
||||
title: '渠道',
|
||||
dataIndex: 'channel',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
isAdminUser ?
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
{<Tag color={colors[parseInt(text) % colors.length]} size='large'> {text} </Tag>}
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '用户',
|
||||
dataIndex: 'username',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
isAdminUser ?
|
||||
<div>
|
||||
<Avatar size="small" color={stringToColor(text)} style={{marginRight: 4}}
|
||||
onClick={() => showUserInfo(record.user_id)}>
|
||||
{typeof text === 'string' && text.slice(0, 1)}
|
||||
</Avatar>
|
||||
{text}
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '令牌',
|
||||
dataIndex: 'token_name',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
<Tag color='grey' size='large' onClick={() => {
|
||||
copyText(text)
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '类型',
|
||||
dataIndex: 'type',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderType(text)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '模型',
|
||||
dataIndex: 'model_name',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
<Tag color={stringToColor(text)} size='large' onClick={() => {
|
||||
copyText(text)
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '用时',
|
||||
dataIndex: 'use_time',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
<Space>
|
||||
{renderUseTime(text)}
|
||||
{renderIsStream(record.is_stream)}
|
||||
</Space>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '提示',
|
||||
dataIndex: 'prompt_tokens',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
{<span> {text} </span>}
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '补全',
|
||||
dataIndex: 'completion_tokens',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
parseInt(text) > 0 && (record.type === 0 || record.type === 2) ?
|
||||
<div>
|
||||
{<span> {text} </span>}
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '花费',
|
||||
dataIndex: 'quota',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
record.type === 0 || record.type === 2 ?
|
||||
<div>
|
||||
{
|
||||
renderQuota(text, 6)
|
||||
}
|
||||
</div>
|
||||
:
|
||||
<></>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '详情',
|
||||
dataIndex: 'content',
|
||||
render: (text, record, index) => {
|
||||
return <Paragraph ellipsis={{ rows: 2, showTooltip: { type: 'popover', opts: { style: { width: 240 } } } }} style={{ maxWidth: 240}}>
|
||||
{text}
|
||||
</Paragraph>
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [loadingStat, setLoadingStat] = useState(false);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [logType, setLogType] = useState(0);
|
||||
const isAdminUser = isAdmin();
|
||||
let now = new Date();
|
||||
// 初始化start_timestamp为前一天
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
token_name: '',
|
||||
model_name: '',
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: ''
|
||||
});
|
||||
const {username, token_name, model_name, start_timestamp, end_timestamp, channel} = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
quota: 0,
|
||||
token: 0
|
||||
});
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
};
|
||||
|
||||
const getLogSelfStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setStat(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setStat(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEyeClick = async () => {
|
||||
setLoadingStat(true);
|
||||
if (isAdminUser) {
|
||||
await getLogStat();
|
||||
} else {
|
||||
await getLogSelfStat();
|
||||
}
|
||||
setShowStat(true);
|
||||
setLoadingStat(false);
|
||||
};
|
||||
|
||||
const showUserInfo = async (userId) => {
|
||||
if (!isAdminUser) {
|
||||
return;
|
||||
}
|
||||
const res = await API.get(`/api/user/${userId}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
Modal.info({
|
||||
title: '用户信息',
|
||||
content: <div style={{padding: 12}}>
|
||||
<p>用户名: {data.username}</p>
|
||||
<p>余额: {renderQuota(data.quota)}</p>
|
||||
<p>已用额度:{renderQuota(data.used_quota)}</p>
|
||||
<p>请求次数:{renderNumber(data.request_count)}</p>
|
||||
</div>,
|
||||
centered: true,
|
||||
})
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
// console.log(logCount);
|
||||
const columns = [{
|
||||
title: '时间', dataIndex: 'timestamp2string'
|
||||
}, {
|
||||
title: '渠道',
|
||||
dataIndex: 'channel',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (isAdminUser ? record.type === 0 || record.type === 2 ? <div>
|
||||
{<Tag color={colors[parseInt(text) % colors.length]} size="large"> {text} </Tag>}
|
||||
</div> : <></> : <></>);
|
||||
}
|
||||
|
||||
const loadLogs = async (startIdx) => {
|
||||
setLoading(true);
|
||||
|
||||
let url = '';
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
|
||||
|
||||
const handlePageChange = page => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadLogs(page - 1).then(r => {
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const refresh = async () => {
|
||||
// setLoading(true);
|
||||
setActivePage(1);
|
||||
await loadLogs(0);
|
||||
};
|
||||
|
||||
const copyText = async (text) => {
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
|
||||
}
|
||||
}, {
|
||||
title: '用户',
|
||||
dataIndex: 'username',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (isAdminUser ? <div>
|
||||
<Avatar size="small" color={stringToColor(text)} style={{ marginRight: 4 }}
|
||||
onClick={() => showUserInfo(record.user_id)}>
|
||||
{typeof text === 'string' && text.slice(0, 1)}
|
||||
</Avatar>
|
||||
{text}
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '令牌', dataIndex: 'token_name', render: (text, record, index) => {
|
||||
return (record.type === 0 || record.type === 2 ? <div>
|
||||
<Tag color="grey" size="large" onClick={() => {
|
||||
copyText(text);
|
||||
}}> {text} </Tag>
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '类型', dataIndex: 'type', render: (text, record, index) => {
|
||||
return (<div>
|
||||
{renderType(text)}
|
||||
</div>);
|
||||
}
|
||||
}, {
|
||||
title: '模型', dataIndex: 'model_name', render: (text, record, index) => {
|
||||
return (record.type === 0 || record.type === 2 ? <div>
|
||||
<Tag color={stringToColor(text)} size="large" onClick={() => {
|
||||
copyText(text);
|
||||
}}> {text} </Tag>
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '用时', dataIndex: 'use_time', render: (text, record, index) => {
|
||||
return (<div>
|
||||
<Space>
|
||||
{renderUseTime(text)}
|
||||
{renderIsStream(record.is_stream)}
|
||||
</Space>
|
||||
</div>);
|
||||
}
|
||||
}, {
|
||||
title: '提示', dataIndex: 'prompt_tokens', render: (text, record, index) => {
|
||||
return (record.type === 0 || record.type === 2 ? <div>
|
||||
{<span> {text} </span>}
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '补全', dataIndex: 'completion_tokens', render: (text, record, index) => {
|
||||
return (parseInt(text) > 0 && (record.type === 0 || record.type === 2) ? <div>
|
||||
{<span> {text} </span>}
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '花费', dataIndex: 'quota', render: (text, record, index) => {
|
||||
return (record.type === 0 || record.type === 2 ? <div>
|
||||
{renderQuota(text, 6)}
|
||||
</div> : <></>);
|
||||
}
|
||||
}, {
|
||||
title: '详情', dataIndex: 'content', render: (text, record, index) => {
|
||||
return <Paragraph ellipsis={{ rows: 2, showTooltip: { type: 'popover', opts: { style: { width: 240 } } } }}
|
||||
style={{ maxWidth: 240 }}>
|
||||
{text}
|
||||
</Paragraph>;
|
||||
}
|
||||
}];
|
||||
|
||||
useEffect(() => {
|
||||
refresh().then();
|
||||
}, [logType]);
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [loadingStat, setLoadingStat] = useState(false);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [logType, setLogType] = useState(0);
|
||||
const isAdminUser = isAdmin();
|
||||
let now = new Date();
|
||||
// 初始化start_timestamp为前一天
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
token_name: '',
|
||||
model_name: '',
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: ''
|
||||
});
|
||||
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
|
||||
|
||||
const searchLogs = async () => {
|
||||
if (searchKeyword === '') {
|
||||
// if keyword is blank, load files instead.
|
||||
await loadLogs(0);
|
||||
setActivePage(1);
|
||||
return;
|
||||
}
|
||||
setSearching(true);
|
||||
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setLogs(data);
|
||||
setActivePage(1);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setSearching(false);
|
||||
};
|
||||
const [stat, setStat] = useState({
|
||||
quota: 0, token: 0
|
||||
});
|
||||
|
||||
const handleKeywordChange = async (e, {value}) => {
|
||||
setSearchKeyword(value.trim());
|
||||
};
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
|
||||
const sortLog = (key) => {
|
||||
if (logs.length === 0) return;
|
||||
setLoading(true);
|
||||
let sortedLogs = [...logs];
|
||||
if (typeof sortedLogs[0][key] === 'string') {
|
||||
sortedLogs.sort((a, b) => {
|
||||
return ('' + a[key]).localeCompare(b[key]);
|
||||
});
|
||||
} else {
|
||||
sortedLogs.sort((a, b) => {
|
||||
if (a[key] === b[key]) return 0;
|
||||
if (a[key] > b[key]) return -1;
|
||||
if (a[key] < b[key]) return 1;
|
||||
});
|
||||
}
|
||||
if (sortedLogs[0].id === logs[0].id) {
|
||||
sortedLogs.reverse();
|
||||
}
|
||||
setLogs(sortedLogs);
|
||||
setLoading(false);
|
||||
};
|
||||
const getLogSelfStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setStat(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setStat(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEyeClick = async () => {
|
||||
setLoadingStat(true);
|
||||
if (isAdminUser) {
|
||||
await getLogStat();
|
||||
} else {
|
||||
await getLogSelfStat();
|
||||
}
|
||||
setShowStat(true);
|
||||
setLoadingStat(false);
|
||||
};
|
||||
|
||||
const showUserInfo = async (userId) => {
|
||||
if (!isAdminUser) {
|
||||
return;
|
||||
}
|
||||
const res = await API.get(`/api/user/${userId}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
Modal.info({
|
||||
title: '用户信息', content: <div style={{ padding: 12 }}>
|
||||
<p>用户名: {data.username}</p>
|
||||
<p>余额: {renderQuota(data.quota)}</p>
|
||||
<p>已用额度:{renderQuota(data.used_quota)}</p>
|
||||
<p>请求次数:{renderNumber(data.request_count)}</p>
|
||||
</div>, centered: true
|
||||
});
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
// console.log(logCount);
|
||||
};
|
||||
|
||||
const loadLogs = async (startIdx, pageSize, logType = 0) => {
|
||||
setLoading(true);
|
||||
|
||||
let url = '';
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * pageSize, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const pageData = logs.slice((activePage - 1) * pageSize, activePage * pageSize);
|
||||
|
||||
const handlePageChange = page => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(logs.length / pageSize) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadLogs(page - 1, pageSize, logType).then(r => {
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
localStorage.setItem('page-size', size + '');
|
||||
setPageSize(size);
|
||||
setActivePage(1);
|
||||
loadLogs(0, size)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
});
|
||||
};
|
||||
|
||||
const refresh = async (localLogType) => {
|
||||
// setLoading(true);
|
||||
setActivePage(1);
|
||||
await loadLogs(0, pageSize, localLogType);
|
||||
};
|
||||
|
||||
const copyText = async (text) => {
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
// console.log('default effect')
|
||||
const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
|
||||
setPageSize(localPageSize);
|
||||
loadLogs(0, localPageSize)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
});
|
||||
}, []);
|
||||
|
||||
const searchLogs = async () => {
|
||||
if (searchKeyword === '') {
|
||||
// if keyword is blank, load files instead.
|
||||
await loadLogs(0, pageSize);
|
||||
setActivePage(1);
|
||||
return;
|
||||
}
|
||||
setSearching(true);
|
||||
const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setLogs(data);
|
||||
setActivePage(1);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setSearching(false);
|
||||
};
|
||||
|
||||
return (<>
|
||||
<Layout>
|
||||
<Header>
|
||||
<Spin spinning={loadingStat}>
|
||||
<h3>使用明细(总消耗额度:
|
||||
<span onClick={handleEyeClick} style={{
|
||||
cursor: 'pointer', color: 'gray'
|
||||
}}>{showStat ? renderQuota(stat.quota) : '点击查看'}</span>
|
||||
)
|
||||
</h3>
|
||||
</Spin>
|
||||
</Header>
|
||||
<Form layout="horizontal" style={{ marginTop: 10 }}>
|
||||
<>
|
||||
<Layout>
|
||||
<Header>
|
||||
<Spin spinning={loadingStat}>
|
||||
<h3>使用明细(总消耗额度:
|
||||
<span onClick={handleEyeClick} style={{cursor: 'pointer', color: 'gray'}}>{showStat?renderQuota(stat.quota):"点击查看"}</span>
|
||||
)
|
||||
</h3>
|
||||
</Spin>
|
||||
</Header>
|
||||
<Form layout='horizontal' style={{marginTop: 10}}>
|
||||
<>
|
||||
<Form.Input field="token_name" label='令牌名称' style={{width: 176}} value={token_name}
|
||||
placeholder={'可选值'} name='token_name'
|
||||
onChange={value => handleInputChange(value, 'token_name')}/>
|
||||
<Form.Input field="model_name" label='模型名称' style={{width: 176}} value={model_name}
|
||||
placeholder='可选值'
|
||||
name='model_name'
|
||||
onChange={value => handleInputChange(value, 'model_name')}/>
|
||||
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')}/>
|
||||
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')}/>
|
||||
{
|
||||
isAdminUser && <>
|
||||
<Form.Input field="channel" label='渠道 ID' style={{width: 176}} value={channel}
|
||||
placeholder='可选值' name='channel'
|
||||
onChange={value => handleInputChange(value, 'channel')}/>
|
||||
<Form.Input field="username" label='用户名称' style={{width: 176}} value={username}
|
||||
placeholder={'可选值'} name='username'
|
||||
onChange={value => handleInputChange(value, 'username')}/>
|
||||
</>
|
||||
}
|
||||
<Form.Section>
|
||||
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
|
||||
onClick={refresh} loading={loading}>查询</Button>
|
||||
</Form.Section>
|
||||
</>
|
||||
</Form>
|
||||
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
|
||||
currentPage: activePage,
|
||||
pageSize: ITEMS_PER_PAGE,
|
||||
total: logCount,
|
||||
pageSizeOpts: [10, 20, 50, 100],
|
||||
onPageChange: handlePageChange,
|
||||
}}/>
|
||||
<Select defaultValue="0" style={{width: 120}} onChange={
|
||||
(value) => {
|
||||
setLogType(parseInt(value));
|
||||
}
|
||||
}>
|
||||
<Select.Option value="0">全部</Select.Option>
|
||||
<Select.Option value="1">充值</Select.Option>
|
||||
<Select.Option value="2">消费</Select.Option>
|
||||
<Select.Option value="3">管理</Select.Option>
|
||||
<Select.Option value="4">系统</Select.Option>
|
||||
</Select>
|
||||
</Layout>
|
||||
<Form.Input field="token_name" label="令牌名称" style={{ width: 176 }} value={token_name}
|
||||
placeholder={'可选值'} name="token_name"
|
||||
onChange={value => handleInputChange(value, 'token_name')} />
|
||||
<Form.Input field="model_name" label="模型名称" style={{ width: 176 }} value={model_name}
|
||||
placeholder="可选值"
|
||||
name="model_name"
|
||||
onChange={value => handleInputChange(value, 'model_name')} />
|
||||
<Form.DatePicker field="start_timestamp" label="起始时间" style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type="dateTime"
|
||||
name="start_timestamp"
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')} />
|
||||
<Form.DatePicker field="end_timestamp" fluid label="结束时间" style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type="dateTime"
|
||||
name="end_timestamp"
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')} />
|
||||
{isAdminUser && <>
|
||||
<Form.Input field="channel" label="渠道 ID" style={{ width: 176 }} value={channel}
|
||||
placeholder="可选值" name="channel"
|
||||
onChange={value => handleInputChange(value, 'channel')} />
|
||||
<Form.Input field="username" label="用户名称" style={{ width: 176 }} value={username}
|
||||
placeholder={'可选值'} name="username"
|
||||
onChange={value => handleInputChange(value, 'username')} />
|
||||
</>}
|
||||
<Form.Section>
|
||||
<Button label="查询" type="primary" htmlType="submit" className="btn-margin-right"
|
||||
onClick={refresh} loading={loading}>查询</Button>
|
||||
</Form.Section>
|
||||
</>
|
||||
);
|
||||
</Form>
|
||||
<Table style={{ marginTop: 5 }} columns={columns} dataSource={pageData} pagination={{
|
||||
currentPage: activePage,
|
||||
pageSize: pageSize,
|
||||
total: logCount,
|
||||
pageSizeOpts: [10, 20, 50, 100],
|
||||
showSizeChanger: true,
|
||||
onPageSizeChange: (size) => {
|
||||
handlePageSizeChange(size).then();
|
||||
},
|
||||
onPageChange: handlePageChange
|
||||
}} />
|
||||
<Select defaultValue="0" style={{ width: 120 }} onChange={(value) => {
|
||||
setLogType(parseInt(value));
|
||||
refresh(parseInt(value)).then();
|
||||
}}>
|
||||
<Select.Option value="0">全部</Select.Option>
|
||||
<Select.Option value="1">充值</Select.Option>
|
||||
<Select.Option value="2">消费</Select.Option>
|
||||
<Select.Option value="3">管理</Select.Option>
|
||||
<Select.Option value="4">系统</Select.Option>
|
||||
</Select>
|
||||
</Layout>
|
||||
</>);
|
||||
};
|
||||
|
||||
export default LogsTable;
|
||||
|
||||
@@ -1,430 +1,454 @@
|
||||
import React, {useEffect, useState} from 'react';
|
||||
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
|
||||
|
||||
import {
|
||||
Table,
|
||||
Avatar,
|
||||
Tag,
|
||||
Form,
|
||||
Button,
|
||||
Layout,
|
||||
Select,
|
||||
Popover,
|
||||
Modal,
|
||||
ImagePreview,
|
||||
Typography, Progress
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {ITEMS_PER_PAGE} from '../constants';
|
||||
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
|
||||
import { Banner, Button, Form, ImagePreview, Layout, Modal, Progress, Table, Tag, Typography } from '@douyinfe/semi-ui';
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
|
||||
|
||||
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
|
||||
'light-blue', 'lime', 'orange', 'pink',
|
||||
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||
]
|
||||
'light-blue', 'lime', 'orange', 'pink',
|
||||
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||
];
|
||||
|
||||
function renderType(type) {
|
||||
switch (type) {
|
||||
case 'IMAGINE':
|
||||
return <Tag color="blue" size='large'>绘图</Tag>;
|
||||
case 'UPSCALE':
|
||||
return <Tag color="orange" size='large'>放大</Tag>;
|
||||
case 'VARIATION':
|
||||
return <Tag color="purple" size='large'>变换</Tag>;
|
||||
case 'DESCRIBE':
|
||||
return <Tag color="yellow" size='large'>图生文</Tag>;
|
||||
case 'BLEAND':
|
||||
return <Tag color="lime" size='large'>图混合</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size='large'>未知</Tag>;
|
||||
}
|
||||
switch (type) {
|
||||
case 'IMAGINE':
|
||||
return <Tag color="blue" size="large">绘图</Tag>;
|
||||
case 'UPSCALE':
|
||||
return <Tag color="orange" size="large">放大</Tag>;
|
||||
case 'VARIATION':
|
||||
return <Tag color="purple" size="large">变换</Tag>;
|
||||
case 'HIGH_VARIATION':
|
||||
return <Tag color="purple" size="large">强变换</Tag>;
|
||||
case 'LOW_VARIATION':
|
||||
return <Tag color="purple" size="large">弱变换</Tag>;
|
||||
case 'PAN':
|
||||
return <Tag color="cyan" size="large">平移</Tag>;
|
||||
case 'DESCRIBE':
|
||||
return <Tag color="yellow" size="large">图生文</Tag>;
|
||||
case 'BLEND':
|
||||
return <Tag color="lime" size="large">图混合</Tag>;
|
||||
case 'SHORTEN':
|
||||
return <Tag color="pink" size="large">缩词</Tag>;
|
||||
case 'REROLL':
|
||||
return <Tag color="indigo" size="large">重绘</Tag>;
|
||||
case 'INPAINT':
|
||||
return <Tag color="violet" size="large">局部重绘-提交</Tag>;
|
||||
case 'ZOOM':
|
||||
return <Tag color="teal" size="large">变焦</Tag>;
|
||||
case 'CUSTOM_ZOOM':
|
||||
return <Tag color="teal" size="large">自定义变焦-提交</Tag>;
|
||||
case 'MODAL':
|
||||
return <Tag color="green" size="large">窗口处理</Tag>;
|
||||
case 'SWAP_FACE':
|
||||
return <Tag color="light-green" size="large">换脸</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size="large">未知</Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function renderCode(code) {
|
||||
switch (code) {
|
||||
case 1:
|
||||
return <Tag color="green" size='large'>已提交</Tag>;
|
||||
case 21:
|
||||
return <Tag color="lime" size='large'>排队中</Tag>;
|
||||
case 22:
|
||||
return <Tag color="orange" size='large'>重复提交</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size='large'>未知</Tag>;
|
||||
}
|
||||
switch (code) {
|
||||
case 1:
|
||||
return <Tag color="green" size="large">已提交</Tag>;
|
||||
case 21:
|
||||
return <Tag color="lime" size="large">等待中</Tag>;
|
||||
case 22:
|
||||
return <Tag color="orange" size="large">重复提交</Tag>;
|
||||
case 0:
|
||||
return <Tag color="yellow" size="large">未提交</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size="large">未知</Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function renderStatus(type) {
|
||||
// Ensure all cases are string literals by adding quotes.
|
||||
switch (type) {
|
||||
case 'SUCCESS':
|
||||
return <Tag color="green" size='large'>成功</Tag>;
|
||||
case 'NOT_START':
|
||||
return <Tag color="grey" size='large'>未启动</Tag>;
|
||||
case 'SUBMITTED':
|
||||
return <Tag color="yellow" size='large'>队列中</Tag>;
|
||||
case 'IN_PROGRESS':
|
||||
return <Tag color="blue" size='large'>执行中</Tag>;
|
||||
case 'FAILURE':
|
||||
return <Tag color="red" size='large'>失败</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size='large'>未知</Tag>;
|
||||
}
|
||||
// Ensure all cases are string literals by adding quotes.
|
||||
switch (type) {
|
||||
case 'SUCCESS':
|
||||
return <Tag color="green" size="large">成功</Tag>;
|
||||
case 'NOT_START':
|
||||
return <Tag color="grey" size="large">未启动</Tag>;
|
||||
case 'SUBMITTED':
|
||||
return <Tag color="yellow" size="large">队列中</Tag>;
|
||||
case 'IN_PROGRESS':
|
||||
return <Tag color="blue" size="large">执行中</Tag>;
|
||||
case 'FAILURE':
|
||||
return <Tag color="red" size="large">失败</Tag>;
|
||||
case 'MODAL':
|
||||
return <Tag color="yellow" size="large">窗口等待</Tag>;
|
||||
default:
|
||||
return <Tag color="white" size="large">未知</Tag>;
|
||||
}
|
||||
}
|
||||
|
||||
const renderTimestamp = (timestampInSeconds) => {
|
||||
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
|
||||
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
|
||||
|
||||
const year = date.getFullYear(); // 获取年份
|
||||
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
|
||||
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
|
||||
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
|
||||
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
|
||||
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
|
||||
const year = date.getFullYear(); // 获取年份
|
||||
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
|
||||
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
|
||||
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
|
||||
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
|
||||
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
|
||||
|
||||
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
|
||||
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
|
||||
};
|
||||
|
||||
|
||||
const LogsTable = () => {
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [modalContent, setModalContent] = useState('');
|
||||
const columns = [
|
||||
{
|
||||
title: '提交时间',
|
||||
dataIndex: 'submit_time',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderTimestamp(text / 1000)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '渠道',
|
||||
dataIndex: 'channel_id',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [modalContent, setModalContent] = useState('');
|
||||
const columns = [
|
||||
{
|
||||
title: '提交时间',
|
||||
dataIndex: 'submit_time',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderTimestamp(text / 1000)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '渠道',
|
||||
dataIndex: 'channel_id',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
|
||||
<div>
|
||||
<Tag color={colors[parseInt(text) % colors.length]} size='large' onClick={() => {
|
||||
copyText(text); // 假设copyText是用于文本复制的函数
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
<div>
|
||||
<Tag color={colors[parseInt(text) % colors.length]} size="large" onClick={() => {
|
||||
copyText(text); // 假设copyText是用于文本复制的函数
|
||||
}}> {text} </Tag>
|
||||
</div>
|
||||
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '类型',
|
||||
dataIndex: 'action',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderType(text)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '任务ID',
|
||||
dataIndex: 'mj_id',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{text}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '提交结果',
|
||||
dataIndex: 'code',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderCode(text)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '任务状态',
|
||||
dataIndex: 'status',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderStatus(text)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '进度',
|
||||
dataIndex: 'progress',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{
|
||||
// 转换例如100%为数字100,如果text未定义,返回0
|
||||
<Progress stroke={record.status === "FAILURE"?"var(--semi-color-warning)":null} percent={text ? parseInt(text.replace('%', '')) : 0} showInfo={true}
|
||||
aria-label="drawing progress"/>
|
||||
}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '结果图片',
|
||||
dataIndex: 'image_url',
|
||||
render: (text, record, index) => {
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
onClick={() => {
|
||||
setModalImageUrl(text); // 更新图片URL状态
|
||||
setIsModalOpenurl(true); // 打开模态框
|
||||
}}
|
||||
>
|
||||
查看图片
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: 'Prompt',
|
||||
dataIndex: 'prompt',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{showTooltip: true}}
|
||||
style={{width: 100}}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: 'PromptEn',
|
||||
dataIndex: 'prompt_en',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{showTooltip: true}}
|
||||
style={{width: 100}}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '失败原因',
|
||||
dataIndex: 'fail_reason',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{showTooltip: true}}
|
||||
style={{width: 100}}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '类型',
|
||||
dataIndex: 'action',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderType(text)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '任务ID',
|
||||
dataIndex: 'mj_id',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{text}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '提交结果',
|
||||
dataIndex: 'code',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderCode(text)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '任务状态',
|
||||
dataIndex: 'status',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderStatus(text)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '进度',
|
||||
dataIndex: 'progress',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{
|
||||
// 转换例如100%为数字100,如果text未定义,返回0
|
||||
<Progress stroke={record.status === 'FAILURE' ? 'var(--semi-color-warning)' : null}
|
||||
percent={text ? parseInt(text.replace('%', '')) : 0} showInfo={true}
|
||||
aria-label="drawing progress" />
|
||||
}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '结果图片',
|
||||
dataIndex: 'image_url',
|
||||
render: (text, record, index) => {
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
return (
|
||||
<Button
|
||||
onClick={() => {
|
||||
setModalImageUrl(text); // 更新图片URL状态
|
||||
setIsModalOpenurl(true); // 打开模态框
|
||||
}}
|
||||
>
|
||||
查看图片
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: 'Prompt',
|
||||
dataIndex: 'prompt',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
|
||||
];
|
||||
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [logType, setLogType] = useState(0);
|
||||
const isAdminUser = isAdmin();
|
||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||
|
||||
// 定义模态框图片URL的状态和更新函数
|
||||
const [modalImageUrl, setModalImageUrl] = useState('');
|
||||
let now = new Date();
|
||||
// 初始化start_timestamp为前一天
|
||||
const [inputs, setInputs] = useState({
|
||||
channel_id: '',
|
||||
mj_id: '',
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
});
|
||||
const {channel_id, mj_id, start_timestamp, end_timestamp} = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
quota: 0,
|
||||
token: 0
|
||||
});
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
};
|
||||
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: 'PromptEn',
|
||||
dataIndex: 'prompt_en',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
// console.log(logCount);
|
||||
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
{
|
||||
title: '失败原因',
|
||||
dataIndex: 'fail_reason',
|
||||
render: (text, record, index) => {
|
||||
// 如果text未定义,返回替代文本,例如空字符串''或其他
|
||||
if (!text) {
|
||||
return '无';
|
||||
}
|
||||
|
||||
return (
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const loadLogs = async (startIdx) => {
|
||||
setLoading(true);
|
||||
];
|
||||
|
||||
let url = '';
|
||||
let localStartTimestamp = Date.parse(start_timestamp);
|
||||
let localEndTimestamp = Date.parse(end_timestamp);
|
||||
if (isAdminUser) {
|
||||
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
} else {
|
||||
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [logType, setLogType] = useState(0);
|
||||
const isAdminUser = isAdmin();
|
||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||
const [showBanner, setShowBanner] = useState(false);
|
||||
|
||||
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
|
||||
// 定义模态框图片URL的状态和更新函数
|
||||
const [modalImageUrl, setModalImageUrl] = useState('');
|
||||
let now = new Date();
|
||||
// 初始化start_timestamp为前一天
|
||||
const [inputs, setInputs] = useState({
|
||||
channel_id: '',
|
||||
mj_id: '',
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
|
||||
});
|
||||
const { channel_id, mj_id, start_timestamp, end_timestamp } = inputs;
|
||||
|
||||
const handlePageChange = page => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadLogs(page - 1).then(r => {
|
||||
});
|
||||
}
|
||||
};
|
||||
const [stat, setStat] = useState({
|
||||
quota: 0,
|
||||
token: 0
|
||||
});
|
||||
|
||||
const refresh = async () => {
|
||||
// setLoading(true);
|
||||
setActivePage(1);
|
||||
await loadLogs(0);
|
||||
};
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
|
||||
const copyText = async (text) => {
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({title: '无法复制到剪贴板,请手动复制', content: text});
|
||||
}
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
// console.log(logCount);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refresh().then();
|
||||
}, [logType]);
|
||||
const loadLogs = async (startIdx) => {
|
||||
setLoading(true);
|
||||
|
||||
let url = '';
|
||||
let localStartTimestamp = Date.parse(start_timestamp);
|
||||
let localEndTimestamp = Date.parse(end_timestamp);
|
||||
if (isAdminUser) {
|
||||
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
} else {
|
||||
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
|
||||
|
||||
<Layout>
|
||||
<Form layout='horizontal' style={{marginTop: 10}}>
|
||||
<>
|
||||
<Form.Input field="channel_id" label='渠道 ID' style={{width: 176}} value={channel_id}
|
||||
placeholder={'可选值'} name='channel_id'
|
||||
onChange={value => handleInputChange(value, 'channel_id')}/>
|
||||
<Form.Input field="mj_id" label='任务 ID' style={{width: 176}} value={mj_id}
|
||||
placeholder='可选值'
|
||||
name='mj_id'
|
||||
onChange={value => handleInputChange(value, 'mj_id')}/>
|
||||
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')}/>
|
||||
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')}/>
|
||||
const handlePageChange = page => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadLogs(page - 1).then(r => {
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
<Form.Section>
|
||||
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
|
||||
onClick={refresh}>查询</Button>
|
||||
</Form.Section>
|
||||
</>
|
||||
</Form>
|
||||
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
|
||||
currentPage: activePage,
|
||||
pageSize: ITEMS_PER_PAGE,
|
||||
total: logCount,
|
||||
pageSizeOpts: [10, 20, 50, 100],
|
||||
onPageChange: handlePageChange,
|
||||
}} loading={loading}/>
|
||||
<Modal
|
||||
visible={isModalOpen}
|
||||
onOk={() => setIsModalOpen(false)}
|
||||
onCancel={() => setIsModalOpen(false)}
|
||||
closable={null}
|
||||
bodyStyle={{height: '400px', overflow: 'auto'}} // 设置模态框内容区域样式
|
||||
width={800} // 设置模态框宽度
|
||||
>
|
||||
<p style={{whiteSpace: 'pre-line'}}>{modalContent}</p>
|
||||
</Modal>
|
||||
<ImagePreview
|
||||
src={modalImageUrl}
|
||||
visible={isModalOpenurl}
|
||||
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
|
||||
/>
|
||||
const refresh = async () => {
|
||||
// setLoading(true);
|
||||
setActivePage(1);
|
||||
await loadLogs(0);
|
||||
};
|
||||
|
||||
</Layout>
|
||||
</>
|
||||
);
|
||||
const copyText = async (text) => {
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
// setSearchKeyword(text);
|
||||
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refresh().then();
|
||||
}, [logType]);
|
||||
|
||||
useEffect(() => {
|
||||
const mjNotifyEnabled = localStorage.getItem('mj_notify_enabled');
|
||||
if (mjNotifyEnabled !== 'true') {
|
||||
setShowBanner(true);
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
<Layout>
|
||||
{isAdminUser && showBanner ? <Banner
|
||||
type="info"
|
||||
description="当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。"
|
||||
/> : <></>
|
||||
}
|
||||
<Form layout="horizontal" style={{ marginTop: 10 }}>
|
||||
<>
|
||||
<Form.Input field="channel_id" label="渠道 ID" style={{ width: 176 }} value={channel_id}
|
||||
placeholder={'可选值'} name="channel_id"
|
||||
onChange={value => handleInputChange(value, 'channel_id')} />
|
||||
<Form.Input field="mj_id" label="任务 ID" style={{ width: 176 }} value={mj_id}
|
||||
placeholder="可选值"
|
||||
name="mj_id"
|
||||
onChange={value => handleInputChange(value, 'mj_id')} />
|
||||
<Form.DatePicker field="start_timestamp" label="起始时间" style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type="dateTime"
|
||||
name="start_timestamp"
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')} />
|
||||
<Form.DatePicker field="end_timestamp" fluid label="结束时间" style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type="dateTime"
|
||||
name="end_timestamp"
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')} />
|
||||
|
||||
<Form.Section>
|
||||
<Button label="查询" type="primary" htmlType="submit" className="btn-margin-right"
|
||||
onClick={refresh}>查询</Button>
|
||||
</Form.Section>
|
||||
</>
|
||||
</Form>
|
||||
<Table style={{ marginTop: 5 }} columns={columns} dataSource={pageData} pagination={{
|
||||
currentPage: activePage,
|
||||
pageSize: ITEMS_PER_PAGE,
|
||||
total: logCount,
|
||||
pageSizeOpts: [10, 20, 50, 100],
|
||||
onPageChange: handlePageChange
|
||||
}} loading={loading} />
|
||||
<Modal
|
||||
visible={isModalOpen}
|
||||
onOk={() => setIsModalOpen(false)}
|
||||
onCancel={() => setIsModalOpen(false)}
|
||||
closable={null}
|
||||
bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
|
||||
width={800} // 设置模态框宽度
|
||||
>
|
||||
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
|
||||
</Modal>
|
||||
<ImagePreview
|
||||
src={modalImageUrl}
|
||||
visible={isModalOpenurl}
|
||||
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
|
||||
/>
|
||||
|
||||
</Layout>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default LogsTable;
|
||||
|
||||
@@ -1,453 +1,468 @@
|
||||
import React, {useEffect, useState} from 'react';
|
||||
import {Divider, Form, Grid, Header} from 'semantic-ui-react';
|
||||
import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
||||
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
QuotaForNewUser: 0,
|
||||
QuotaForInviter: 0,
|
||||
QuotaForInvitee: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
ModelRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
TopUpLink: '',
|
||||
ChatLink: '',
|
||||
ChatLink2: '', // 添加的新状态变量
|
||||
QuotaPerUnit: 0,
|
||||
AutomaticDisableChannelEnabled: '',
|
||||
AutomaticEnableChannelEnabled: '',
|
||||
ChannelDisableThreshold: 0,
|
||||
LogConsumeEnabled: '',
|
||||
DisplayInCurrencyEnabled: '',
|
||||
DisplayTokenStatEnabled: '',
|
||||
DrawingEnabled: '',
|
||||
DataExportEnabled: '',
|
||||
DataExportDefaultTime: 'hour',
|
||||
DataExportInterval: 5,
|
||||
DefaultCollapseSidebar: '', // 默认折叠侧边栏
|
||||
RetryTimes: 0
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
QuotaForNewUser: 0,
|
||||
QuotaForInviter: 0,
|
||||
QuotaForInvitee: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
ModelRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
TopUpLink: '',
|
||||
ChatLink: '',
|
||||
ChatLink2: '', // 添加的新状态变量
|
||||
QuotaPerUnit: 0,
|
||||
AutomaticDisableChannelEnabled: '',
|
||||
AutomaticEnableChannelEnabled: '',
|
||||
ChannelDisableThreshold: 0,
|
||||
LogConsumeEnabled: '',
|
||||
DisplayInCurrencyEnabled: '',
|
||||
DisplayTokenStatEnabled: '',
|
||||
MjNotifyEnabled: '',
|
||||
DrawingEnabled: '',
|
||||
DataExportEnabled: '',
|
||||
DataExportDefaultTime: 'hour',
|
||||
DataExportInterval: 5,
|
||||
DefaultCollapseSidebar: '', // 默认折叠侧边栏
|
||||
RetryTimes: 0
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
// 精确时间选项(小时,天,周)
|
||||
const timeOptions = [
|
||||
{ key: 'hour', text: '小时', value: 'hour' },
|
||||
{ key: 'day', text: '天', value: 'day' },
|
||||
{ key: 'week', text: '周', value: 'week' }
|
||||
];
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setOriginInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
}, []);
|
||||
|
||||
const updateOption = async (key, value) => {
|
||||
setLoading(true);
|
||||
if (key.endsWith('Enabled')) {
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
}
|
||||
if (key === 'DefaultCollapseSidebar') {
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
}
|
||||
console.log(key, value);
|
||||
const res = await API.put('/api/option/', {
|
||||
key,
|
||||
value
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);
|
||||
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
// 精确时间选项(小时,天,周)
|
||||
const timeOptions = [
|
||||
{key: 'hour', text: '小时', value: 'hour'},
|
||||
{key: 'day', text: '天', value: 'day'},
|
||||
{key: 'week', text: '周', value: 'week'}
|
||||
];
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setOriginInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
setInputs((inputs) => ({ ...inputs, [key]: value }));
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
}, []);
|
||||
const handleInputChange = async (e, { name, value }) => {
|
||||
if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime' || name === 'DefaultCollapseSidebar') {
|
||||
if (name === 'DataExportDefaultTime') {
|
||||
localStorage.setItem('data_export_default_time', value);
|
||||
} else if (name === 'MjNotifyEnabled') {
|
||||
localStorage.setItem('mj_notify_enabled', value);
|
||||
}
|
||||
await updateOption(name, value);
|
||||
} else {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
}
|
||||
};
|
||||
|
||||
const updateOption = async (key, value) => {
|
||||
setLoading(true);
|
||||
if (key.endsWith('Enabled')) {
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
const submitConfig = async (group) => {
|
||||
switch (group) {
|
||||
case 'monitor':
|
||||
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
|
||||
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
|
||||
}
|
||||
if (key === 'DefaultCollapseSidebar') {
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
|
||||
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
|
||||
}
|
||||
console.log(key, value)
|
||||
const res = await API.put('/api/option/', {
|
||||
key,
|
||||
value
|
||||
});
|
||||
const {success, message} = res.data;
|
||||
if (success) {
|
||||
setInputs((inputs) => ({...inputs, [key]: value}));
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const handleInputChange = async (e, {name, value}) => {
|
||||
if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime' || name === 'DefaultCollapseSidebar') {
|
||||
if (name === 'DataExportDefaultTime') {
|
||||
localStorage.setItem('data_export_default_time', value);
|
||||
}
|
||||
await updateOption(name, value);
|
||||
} else {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
}
|
||||
};
|
||||
|
||||
const submitConfig = async (group) => {
|
||||
switch (group) {
|
||||
case 'monitor':
|
||||
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
|
||||
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
|
||||
}
|
||||
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
|
||||
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
|
||||
}
|
||||
break;
|
||||
case 'ratio':
|
||||
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
|
||||
if (!verifyJSON(inputs.ModelRatio)) {
|
||||
showError('模型倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelRatio', inputs.ModelRatio);
|
||||
}
|
||||
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
|
||||
if (!verifyJSON(inputs.GroupRatio)) {
|
||||
showError('分组倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('GroupRatio', inputs.GroupRatio);
|
||||
}
|
||||
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
|
||||
if (!verifyJSON(inputs.ModelPrice)) {
|
||||
showError('模型固定价格不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelPrice', inputs.ModelPrice);
|
||||
}
|
||||
break;
|
||||
case 'quota':
|
||||
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
|
||||
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
|
||||
}
|
||||
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
|
||||
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
|
||||
}
|
||||
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
|
||||
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
|
||||
}
|
||||
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
|
||||
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
|
||||
}
|
||||
break;
|
||||
case 'general':
|
||||
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
||||
await updateOption('TopUpLink', inputs.TopUpLink);
|
||||
}
|
||||
if (originInputs['ChatLink'] !== inputs.ChatLink) {
|
||||
await updateOption('ChatLink', inputs.ChatLink);
|
||||
}
|
||||
if (originInputs['ChatLink2'] !== inputs.ChatLink2) {
|
||||
await updateOption('ChatLink2', inputs.ChatLink2);
|
||||
}
|
||||
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
||||
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
||||
}
|
||||
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
const deleteHistoryLogs = async () => {
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
break;
|
||||
case 'ratio':
|
||||
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
|
||||
if (!verifyJSON(inputs.ModelRatio)) {
|
||||
showError('模型倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelRatio', inputs.ModelRatio);
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
<Header as='h3'>
|
||||
通用设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label='充值链接'
|
||||
name='TopUpLink'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.TopUpLink}
|
||||
type='link'
|
||||
placeholder='例如发卡网站的购买链接'
|
||||
/>
|
||||
<Form.Input
|
||||
label='默认聊天页面链接'
|
||||
name='ChatLink'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ChatLink}
|
||||
type='link'
|
||||
placeholder='例如 ChatGPT Next Web 的部署地址'
|
||||
/>
|
||||
<Form.Input
|
||||
label='聊天页面2链接'
|
||||
name='ChatLink2'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ChatLink2}
|
||||
type='link'
|
||||
placeholder='例如 ChatGPT Web & Midjourney 的部署地址'
|
||||
/>
|
||||
<Form.Input
|
||||
label='单位美元额度'
|
||||
name='QuotaPerUnit'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.QuotaPerUnit}
|
||||
type='number'
|
||||
step='0.01'
|
||||
placeholder='一单位货币能兑换的额度'
|
||||
/>
|
||||
<Form.Input
|
||||
label='失败重试次数'
|
||||
name='RetryTimes'
|
||||
type={'number'}
|
||||
step='1'
|
||||
min='0'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.RetryTimes}
|
||||
placeholder='失败重试次数'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
|
||||
if (!verifyJSON(inputs.GroupRatio)) {
|
||||
showError('分组倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('GroupRatio', inputs.GroupRatio);
|
||||
}
|
||||
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
|
||||
if (!verifyJSON(inputs.ModelPrice)) {
|
||||
showError('模型固定价格不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelPrice', inputs.ModelPrice);
|
||||
}
|
||||
break;
|
||||
case 'quota':
|
||||
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
|
||||
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
|
||||
}
|
||||
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
|
||||
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
|
||||
}
|
||||
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
|
||||
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
|
||||
}
|
||||
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
|
||||
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
|
||||
}
|
||||
break;
|
||||
case 'general':
|
||||
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
||||
await updateOption('TopUpLink', inputs.TopUpLink);
|
||||
}
|
||||
if (originInputs['ChatLink'] !== inputs.ChatLink) {
|
||||
await updateOption('ChatLink', inputs.ChatLink);
|
||||
}
|
||||
if (originInputs['ChatLink2'] !== inputs.ChatLink2) {
|
||||
await updateOption('ChatLink2', inputs.ChatLink2);
|
||||
}
|
||||
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
||||
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
||||
}
|
||||
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||
label='以货币形式显示额度'
|
||||
name='DisplayInCurrencyEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayTokenStatEnabled === 'true'}
|
||||
label='Billing 相关 API 显示令牌额度而非用户额度'
|
||||
name='DisplayTokenStatEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DrawingEnabled === 'true'}
|
||||
label='启用绘图功能'
|
||||
name='DrawingEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DefaultCollapseSidebar === 'true'}
|
||||
label='默认折叠侧边栏'
|
||||
name='DefaultCollapseSidebar'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('general').then();
|
||||
}}>保存通用设置</Form.Button><Divider/>
|
||||
<Header as='h3'>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||
name='history_timestamp'
|
||||
onChange={(e, {name, value}) => {
|
||||
setHistoryTimestamp(value);
|
||||
}}/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
<Divider/>
|
||||
<Header as='h3'>
|
||||
数据看板
|
||||
</Header>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DataExportEnabled === 'true'}
|
||||
label='启用数据看板(实验性)'
|
||||
name='DataExportEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Group>
|
||||
<Form.Input
|
||||
label='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
|
||||
name='DataExportInterval'
|
||||
type={'number'}
|
||||
step='1'
|
||||
min='1'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.DataExportInterval}
|
||||
placeholder='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
|
||||
/>
|
||||
<Form.Select
|
||||
label='数据看板默认时间粒度(仅修改展示粒度,统计精确到小时)'
|
||||
options={timeOptions}
|
||||
name='DataExportDefaultTime'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.DataExportDefaultTime}
|
||||
placeholder='数据看板默认时间粒度'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider/>
|
||||
<Header as='h3'>
|
||||
监控设置
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='最长响应时间'
|
||||
name='ChannelDisableThreshold'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ChannelDisableThreshold}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='单位秒,当运行通道全部测试时,超过此时间将自动禁用通道'
|
||||
/>
|
||||
<Form.Input
|
||||
label='额度提醒阈值'
|
||||
name='QuotaRemindThreshold'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.QuotaRemindThreshold}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='低于此额度时将发送邮件提醒用户'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
|
||||
label='失败时自动禁用通道'
|
||||
name='AutomaticDisableChannelEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
|
||||
label='成功时自动启用通道'
|
||||
name='AutomaticEnableChannelEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('monitor').then();
|
||||
}}>保存监控设置</Form.Button>
|
||||
<Divider/>
|
||||
<Header as='h3'>
|
||||
额度设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label='新用户初始额度'
|
||||
name='QuotaForNewUser'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.QuotaForNewUser}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='例如:100'
|
||||
/>
|
||||
<Form.Input
|
||||
label='请求预扣费额度'
|
||||
name='PreConsumedQuota'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.PreConsumedQuota}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='请求结束后多退少补'
|
||||
/>
|
||||
<Form.Input
|
||||
label='邀请新用户奖励额度'
|
||||
name='QuotaForInviter'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.QuotaForInviter}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='例如:2000'
|
||||
/>
|
||||
<Form.Input
|
||||
label='新用户使用邀请码奖励额度'
|
||||
name='QuotaForInvitee'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.QuotaForInvitee}
|
||||
type='number'
|
||||
min='0'
|
||||
placeholder='例如:1000'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('quota').then();
|
||||
}}>保存额度设置</Form.Button>
|
||||
<Divider/>
|
||||
<Header as='h3'>
|
||||
倍率设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
|
||||
name='ModelPrice'
|
||||
onChange={handleInputChange}
|
||||
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ModelPrice}
|
||||
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1,一次消耗0.1刀'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型倍率'
|
||||
name='ModelRatio'
|
||||
onChange={handleInputChange}
|
||||
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ModelRatio}
|
||||
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='分组倍率'
|
||||
name='GroupRatio'
|
||||
onChange={handleInputChange}
|
||||
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
|
||||
autoComplete='new-password'
|
||||
value={inputs.GroupRatio}
|
||||
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('ratio').then();
|
||||
}}>保存倍率设置</Form.Button>
|
||||
</Form>
|
||||
</Grid.Column>
|
||||
</Grid>
|
||||
)
|
||||
;
|
||||
const deleteHistoryLogs = async () => {
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
return;
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
<Header as="h3">
|
||||
通用设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label="充值链接"
|
||||
name="TopUpLink"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.TopUpLink}
|
||||
type="link"
|
||||
placeholder="例如发卡网站的购买链接"
|
||||
/>
|
||||
<Form.Input
|
||||
label="默认聊天页面链接"
|
||||
name="ChatLink"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.ChatLink}
|
||||
type="link"
|
||||
placeholder="例如 ChatGPT Next Web 的部署地址"
|
||||
/>
|
||||
<Form.Input
|
||||
label="聊天页面2链接"
|
||||
name="ChatLink2"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.ChatLink2}
|
||||
type="link"
|
||||
placeholder="例如 ChatGPT Web & Midjourney 的部署地址"
|
||||
/>
|
||||
<Form.Input
|
||||
label="单位美元额度"
|
||||
name="QuotaPerUnit"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.QuotaPerUnit}
|
||||
type="number"
|
||||
step="0.01"
|
||||
placeholder="一单位货币能兑换的额度"
|
||||
/>
|
||||
<Form.Input
|
||||
label="失败重试次数"
|
||||
name="RetryTimes"
|
||||
type={'number'}
|
||||
step="1"
|
||||
min="0"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.RetryTimes}
|
||||
placeholder="失败重试次数"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||
label="以货币形式显示额度"
|
||||
name="DisplayInCurrencyEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DisplayTokenStatEnabled === 'true'}
|
||||
label="Billing 相关 API 显示令牌额度而非用户额度"
|
||||
name="DisplayTokenStatEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DefaultCollapseSidebar === 'true'}
|
||||
label="默认折叠侧边栏"
|
||||
name="DefaultCollapseSidebar"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('general').then();
|
||||
}}>保存通用设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
绘图设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DrawingEnabled === 'true'}
|
||||
label="启用绘图功能"
|
||||
name="DrawingEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.MjNotifyEnabled === 'true'}
|
||||
label="允许回调(会泄露服务器ip地址)"
|
||||
name="MjNotifyEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label="启用额度消费日志记录"
|
||||
name="LogConsumeEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label="目标时间" value={historyTimestamp} type="datetime-local"
|
||||
name="history_timestamp"
|
||||
onChange={(e, { name, value }) => {
|
||||
setHistoryTimestamp(value);
|
||||
}} />
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
数据看板
|
||||
</Header>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DataExportEnabled === 'true'}
|
||||
label="启用数据看板(实验性)"
|
||||
name="DataExportEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Group>
|
||||
<Form.Input
|
||||
label="数据看板更新间隔(分钟,设置过短会影响数据库性能)"
|
||||
name="DataExportInterval"
|
||||
type={'number'}
|
||||
step="1"
|
||||
min="1"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.DataExportInterval}
|
||||
placeholder="数据看板更新间隔(分钟,设置过短会影响数据库性能)"
|
||||
/>
|
||||
<Form.Select
|
||||
label="数据看板默认时间粒度(仅修改展示粒度,统计精确到小时)"
|
||||
options={timeOptions}
|
||||
name="DataExportDefaultTime"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.DataExportDefaultTime}
|
||||
placeholder="数据看板默认时间粒度"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
监控设置
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label="最长响应时间"
|
||||
name="ChannelDisableThreshold"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.ChannelDisableThreshold}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="单位秒,当运行通道全部测试时,超过此时间将自动禁用通道"
|
||||
/>
|
||||
<Form.Input
|
||||
label="额度提醒阈值"
|
||||
name="QuotaRemindThreshold"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.QuotaRemindThreshold}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="低于此额度时将发送邮件提醒用户"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
|
||||
label="失败时自动禁用通道"
|
||||
name="AutomaticDisableChannelEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
|
||||
label="成功时自动启用通道"
|
||||
name="AutomaticEnableChannelEnabled"
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('monitor').then();
|
||||
}}>保存监控设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
额度设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label="新用户初始额度"
|
||||
name="QuotaForNewUser"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.QuotaForNewUser}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="例如:100"
|
||||
/>
|
||||
<Form.Input
|
||||
label="请求预扣费额度"
|
||||
name="PreConsumedQuota"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.PreConsumedQuota}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="请求结束后多退少补"
|
||||
/>
|
||||
<Form.Input
|
||||
label="邀请新用户奖励额度"
|
||||
name="QuotaForInviter"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.QuotaForInviter}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="例如:2000"
|
||||
/>
|
||||
<Form.Input
|
||||
label="新用户使用邀请码奖励额度"
|
||||
name="QuotaForInvitee"
|
||||
onChange={handleInputChange}
|
||||
autoComplete="new-password"
|
||||
value={inputs.QuotaForInvitee}
|
||||
type="number"
|
||||
min="0"
|
||||
placeholder="例如:1000"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('quota').then();
|
||||
}}>保存额度设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as="h3">
|
||||
倍率设置
|
||||
</Header>
|
||||
<Form.Group widths="equal">
|
||||
<Form.TextArea
|
||||
label="模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)"
|
||||
name="ModelPrice"
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autoComplete="new-password"
|
||||
value={inputs.ModelPrice}
|
||||
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1,一次消耗0.1刀'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths="equal">
|
||||
<Form.TextArea
|
||||
label="模型倍率"
|
||||
name="ModelRatio"
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autoComplete="new-password"
|
||||
value={inputs.ModelRatio}
|
||||
placeholder="为一个 JSON 文本,键为模型名称,值为倍率"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths="equal">
|
||||
<Form.TextArea
|
||||
label="分组倍率"
|
||||
name="GroupRatio"
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autoComplete="new-password"
|
||||
value={inputs.GroupRatio}
|
||||
placeholder="为一个 JSON 文本,键为分组名称,值为倍率"
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('ratio').then();
|
||||
}}>保存倍率设置</Form.Button>
|
||||
</Form>
|
||||
</Grid.Column>
|
||||
</Grid>
|
||||
)
|
||||
;
|
||||
};
|
||||
|
||||
export default OperationSetting;
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react';
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui';
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import { marked } from 'marked';
|
||||
|
||||
const OtherSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
Footer: '',
|
||||
Notice: '',
|
||||
About: '',
|
||||
SystemName: '',
|
||||
Logo: '',
|
||||
Footer: '',
|
||||
About: '',
|
||||
HomePageContent: ''
|
||||
});
|
||||
let [loading, setLoading] = useState(false);
|
||||
@@ -19,25 +19,6 @@ const OtherSetting = () => {
|
||||
content: ''
|
||||
});
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key in inputs) {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
});
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
}, []);
|
||||
|
||||
const updateOption = async (key, value) => {
|
||||
setLoading(true);
|
||||
@@ -54,33 +35,103 @@ const OtherSetting = () => {
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const handleInputChange = async (e, { name, value }) => {
|
||||
const [loadingInput, setLoadingInput] = useState({
|
||||
Notice: false,
|
||||
SystemName: false,
|
||||
Logo: false,
|
||||
HomePageContent: false,
|
||||
About: false,
|
||||
Footer: false
|
||||
});
|
||||
const handleInputChange = async (value, e) => {
|
||||
const name = e.target.id;
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
|
||||
// 通用设置
|
||||
const formAPISettingGeneral = useRef();
|
||||
// 通用设置 - Notice
|
||||
const submitNotice = async () => {
|
||||
await updateOption('Notice', inputs.Notice);
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Notice: true }));
|
||||
await updateOption('Notice', inputs.Notice);
|
||||
showSuccess('公告已更新');
|
||||
} catch (error) {
|
||||
console.error('公告更新失败', error);
|
||||
showError('公告更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Notice: false }));
|
||||
}
|
||||
};
|
||||
|
||||
const submitFooter = async () => {
|
||||
await updateOption('Footer', inputs.Footer);
|
||||
};
|
||||
|
||||
// 个性化设置
|
||||
const formAPIPersonalization = useRef();
|
||||
// 个性化设置 - SystemName
|
||||
const submitSystemName = async () => {
|
||||
await updateOption('SystemName', inputs.SystemName);
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, SystemName: true }));
|
||||
await updateOption('SystemName', inputs.SystemName);
|
||||
showSuccess('系统名称已更新');
|
||||
} catch (error) {
|
||||
console.error('系统名称更新失败', error);
|
||||
showError('系统名称更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, SystemName: false }));
|
||||
}
|
||||
};
|
||||
|
||||
// 个性化设置 - Logo
|
||||
const submitLogo = async () => {
|
||||
await updateOption('Logo', inputs.Logo);
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Logo: true }));
|
||||
await updateOption('Logo', inputs.Logo);
|
||||
showSuccess('Logo 已更新');
|
||||
} catch (error) {
|
||||
console.error('Logo 更新失败', error);
|
||||
showError('Logo 更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Logo: false }));
|
||||
}
|
||||
};
|
||||
|
||||
const submitAbout = async () => {
|
||||
await updateOption('About', inputs.About);
|
||||
};
|
||||
|
||||
// 个性化设置 - 首页内容
|
||||
const submitOption = async (key) => {
|
||||
await updateOption(key, inputs[key]);
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, HomePageContent: true }));
|
||||
await updateOption(key, inputs[key]);
|
||||
showSuccess('首页内容已更新');
|
||||
} catch (error) {
|
||||
console.error('首页内容更新失败', error);
|
||||
showError('首页内容更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, HomePageContent: false }));
|
||||
}
|
||||
};
|
||||
// 个性化设置 - 关于
|
||||
const submitAbout = async () => {
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, About: true }));
|
||||
await updateOption('About', inputs.About);
|
||||
showSuccess('关于内容已更新');
|
||||
} catch (error) {
|
||||
console.error('关于内容更新失败', error);
|
||||
showError('关于内容更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, About: false }));
|
||||
}
|
||||
};
|
||||
// 个性化设置 - 页脚
|
||||
const submitFooter = async () => {
|
||||
try {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Footer: true }));
|
||||
await updateOption('Footer', inputs.Footer);
|
||||
showSuccess('页脚内容已更新');
|
||||
} catch (error) {
|
||||
console.error('页脚内容更新失败', error);
|
||||
showError('页脚内容更新失败');
|
||||
} finally {
|
||||
setLoadingInput((loadingInput) => ({ ...loadingInput, Footer: false }));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const openGitHubRelease = () => {
|
||||
window.location =
|
||||
@@ -102,82 +153,102 @@ const OtherSetting = () => {
|
||||
setShowUpdateModal(true);
|
||||
}
|
||||
};
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key in inputs) {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
});
|
||||
setInputs(newInputs);
|
||||
formAPISettingGeneral.current.setValues(newInputs);
|
||||
formAPIPersonalization.current.setValues(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions();
|
||||
}, []);
|
||||
|
||||
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
<Header as='h3'>通用设置</Header>
|
||||
{/*<Form.Button onClick={checkUpdate}>检查更新</Form.Button>*/}
|
||||
<Form.Group widths='equal'>
|
||||
<Row>
|
||||
<Col span={24}>
|
||||
{/* 通用设置 */}
|
||||
<Form values={inputs} getFormApi={formAPI => formAPISettingGeneral.current = formAPI}
|
||||
style={{ marginBottom: 15 }}>
|
||||
<Form.Section text={'通用设置'}>
|
||||
<Form.TextArea
|
||||
label='公告'
|
||||
placeholder='在此输入新的公告内容,支持 Markdown & HTML 代码'
|
||||
value={inputs.Notice}
|
||||
name='Notice'
|
||||
label={'公告'}
|
||||
placeholder={'在此输入新的公告内容,支持 Markdown & HTML 代码'}
|
||||
field={'Notice'}
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autosize={{ minRows: 6, maxRows: 12 }}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitNotice}>保存公告</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>个性化设置</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='系统名称'
|
||||
placeholder='在此输入系统名称'
|
||||
value={inputs.SystemName}
|
||||
name='SystemName'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitSystemName}>设置系统名称</Form.Button>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='Logo 图片地址'
|
||||
placeholder='在此输入 Logo 图片地址'
|
||||
value={inputs.Logo}
|
||||
name='Logo'
|
||||
type='url'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitLogo}>设置 Logo</Form.Button>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='首页内容'
|
||||
placeholder='在此输入首页内容,支持 Markdown & HTML 代码,设置后首页的状态信息将不再显示。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。'
|
||||
value={inputs.HomePageContent}
|
||||
name='HomePageContent'
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => submitOption('HomePageContent')}>保存首页内容</Form.Button>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='关于'
|
||||
placeholder='在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面。'
|
||||
value={inputs.About}
|
||||
name='About'
|
||||
onChange={handleInputChange}
|
||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitAbout}>保存关于</Form.Button>
|
||||
<Message>移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。</Message>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='页脚'
|
||||
placeholder='在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码'
|
||||
value={inputs.Footer}
|
||||
name='Footer'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitFooter}>设置页脚</Form.Button>
|
||||
<Button onClick={submitNotice} loading={loadingInput['Notice']}>设置公告</Button>
|
||||
</Form.Section>
|
||||
</Form>
|
||||
</Grid.Column>
|
||||
{/* 个性化设置 */}
|
||||
<Form values={inputs} getFormApi={formAPI => formAPIPersonalization.current = formAPI}
|
||||
style={{ marginBottom: 15 }}>
|
||||
<Form.Section text={'个性化设置'}>
|
||||
<Form.Input
|
||||
label={'系统名称'}
|
||||
placeholder={'在此输入系统名称'}
|
||||
field={'SystemName'}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Button onClick={submitSystemName} loading={loadingInput['SystemName']}>设置系统名称</Button>
|
||||
<Form.Input
|
||||
label={'Logo 图片地址'}
|
||||
placeholder={'在此输入 Logo 图片地址'}
|
||||
field={'Logo'}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Button onClick={submitLogo} loading={loadingInput['Logo']}>设置 Logo</Button>
|
||||
<Form.TextArea
|
||||
label={'首页内容'}
|
||||
placeholder={'在此输入首页内容,支持 Markdown & HTML 代码,设置后首页的状态信息将不再显示。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页。'}
|
||||
field={'HomePageContent'}
|
||||
onChange={handleInputChange}
|
||||
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autosize={{ minRows: 6, maxRows: 12 }}
|
||||
/>
|
||||
<Button onClick={() => submitOption('HomePageContent')}
|
||||
loading={loadingInput['HomePageContent']}>设置首页内容</Button>
|
||||
<Form.TextArea
|
||||
label={'关于'}
|
||||
placeholder={'在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面。'}
|
||||
field={'About'}
|
||||
onChange={handleInputChange}
|
||||
style={{ fontFamily: 'JetBrains Mono, Consolas' }}
|
||||
autosize={{ minRows: 6, maxRows: 12 }}
|
||||
/>
|
||||
<Button onClick={submitAbout} loading={loadingInput['About']}>设置关于</Button>
|
||||
{/* */}
|
||||
<Banner
|
||||
fullMode={false}
|
||||
type="info"
|
||||
description="移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。"
|
||||
closeIcon={null}
|
||||
style={{ marginTop: 15 }}
|
||||
/>
|
||||
<Form.Input
|
||||
label={'页脚'}
|
||||
placeholder={'在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码'}
|
||||
field={'Footer'}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Button onClick={submitFooter} loading={loadingInput['Footer']}>设置页脚</Button>
|
||||
</Form.Section>
|
||||
</Form>
|
||||
</Col>
|
||||
{/*<Modal*/}
|
||||
{/* onClose={() => setShowUpdateModal(false)}*/}
|
||||
{/* onOpen={() => setShowUpdateModal(true)}*/}
|
||||
@@ -200,7 +271,7 @@ const OtherSetting = () => {
|
||||
{/* />*/}
|
||||
{/* </Modal.Actions>*/}
|
||||
{/*</Modal>*/}
|
||||
</Grid>
|
||||
</Row>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react';
|
||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
||||
import { API, copy, showError, showNotice } from '../helpers';
|
||||
import { useSearchParams } from 'react-router-dom';
|
||||
|
||||
const PasswordResetConfirm = () => {
|
||||
const [inputs, setInputs] = useState({
|
||||
email: '',
|
||||
token: '',
|
||||
token: ''
|
||||
});
|
||||
const { email, token } = inputs;
|
||||
|
||||
@@ -23,7 +23,7 @@ const PasswordResetConfirm = () => {
|
||||
let email = searchParams.get('email');
|
||||
setInputs({
|
||||
token,
|
||||
email,
|
||||
email
|
||||
});
|
||||
}, []);
|
||||
|
||||
@@ -37,7 +37,7 @@ const PasswordResetConfirm = () => {
|
||||
setDisableButton(false);
|
||||
setCountdown(30);
|
||||
}
|
||||
return () => clearInterval(countdownInterval);
|
||||
return () => clearInterval(countdownInterval);
|
||||
}, [disableButton, countdown]);
|
||||
|
||||
async function handleSubmit(e) {
|
||||
@@ -46,7 +46,7 @@ const PasswordResetConfirm = () => {
|
||||
setLoading(true);
|
||||
const res = await API.post(`/api/user/reset`, {
|
||||
email,
|
||||
token,
|
||||
token
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
@@ -59,44 +59,44 @@ const PasswordResetConfirm = () => {
|
||||
}
|
||||
setLoading(false);
|
||||
}
|
||||
|
||||
|
||||
return (
|
||||
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
||||
<Grid textAlign="center" style={{ marginTop: '48px' }}>
|
||||
<Grid.Column style={{ maxWidth: 450 }}>
|
||||
<Header as='h2' color='' textAlign='center'>
|
||||
<Image src='/logo.png' /> 密码重置确认
|
||||
<Header as="h2" color="" textAlign="center">
|
||||
<Image src="/logo.png" /> 密码重置确认
|
||||
</Header>
|
||||
<Form size='large'>
|
||||
<Form size="large">
|
||||
<Segment>
|
||||
<Form.Input
|
||||
fluid
|
||||
icon='mail'
|
||||
iconPosition='left'
|
||||
placeholder='邮箱地址'
|
||||
name='email'
|
||||
icon="mail"
|
||||
iconPosition="left"
|
||||
placeholder="邮箱地址"
|
||||
name="email"
|
||||
value={email}
|
||||
readOnly
|
||||
/>
|
||||
{newPassword && (
|
||||
<Form.Input
|
||||
fluid
|
||||
icon='lock'
|
||||
iconPosition='left'
|
||||
placeholder='新密码'
|
||||
name='newPassword'
|
||||
value={newPassword}
|
||||
readOnly
|
||||
onClick={(e) => {
|
||||
e.target.select();
|
||||
navigator.clipboard.writeText(newPassword);
|
||||
showNotice(`密码已复制到剪贴板:${newPassword}`);
|
||||
}}
|
||||
/>
|
||||
fluid
|
||||
icon="lock"
|
||||
iconPosition="left"
|
||||
placeholder="新密码"
|
||||
name="newPassword"
|
||||
value={newPassword}
|
||||
readOnly
|
||||
onClick={(e) => {
|
||||
e.target.select();
|
||||
navigator.clipboard.writeText(newPassword);
|
||||
showNotice(`密码已复制到剪贴板:${newPassword}`);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<Button
|
||||
color='green'
|
||||
color="green"
|
||||
fluid
|
||||
size='large'
|
||||
size="large"
|
||||
onClick={handleSubmit}
|
||||
loading={loading}
|
||||
disabled={disableButton}
|
||||
@@ -107,7 +107,7 @@ const PasswordResetConfirm = () => {
|
||||
</Form>
|
||||
</Grid.Column>
|
||||
</Grid>
|
||||
);
|
||||
);
|
||||
};
|
||||
|
||||
export default PasswordResetConfirm;
|
||||
|
||||
@@ -56,19 +56,19 @@ const PasswordResetForm = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
||||
<Grid textAlign="center" style={{ marginTop: '48px' }}>
|
||||
<Grid.Column style={{ maxWidth: 450 }}>
|
||||
<Header as='h2' color='' textAlign='center'>
|
||||
<Image src='/logo.png' /> 密码重置
|
||||
<Header as="h2" color="" textAlign="center">
|
||||
<Image src="/logo.png" /> 密码重置
|
||||
</Header>
|
||||
<Form size='large'>
|
||||
<Form size="large">
|
||||
<Segment>
|
||||
<Form.Input
|
||||
fluid
|
||||
icon='mail'
|
||||
iconPosition='left'
|
||||
placeholder='邮箱地址'
|
||||
name='email'
|
||||
icon="mail"
|
||||
iconPosition="left"
|
||||
placeholder="邮箱地址"
|
||||
name="email"
|
||||
value={email}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
@@ -83,9 +83,9 @@ const PasswordResetForm = () => {
|
||||
<></>
|
||||
)}
|
||||
<Button
|
||||
color='green'
|
||||
color="green"
|
||||
fluid
|
||||
size='large'
|
||||
size="large"
|
||||
onClick={handleSubmit}
|
||||
loading={loading}
|
||||
disabled={disableButton}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,7 +5,7 @@ import { history } from '../helpers';
|
||||
|
||||
function PrivateRoute({ children }) {
|
||||
if (!localStorage.getItem('user')) {
|
||||
return <Navigate to='/login' state={{ from: history.location }} />;
|
||||
return <Navigate to="/login" state={{ from: history.location }} />;
|
||||
}
|
||||
return children;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user