Compare commits

..

No commits in common. "main" and "v0.1.3.1" have entirely different histories.

295 changed files with 14402 additions and 39545 deletions

View File

@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: 交流社区
url: https://linux.do
about: 项目交流社区
- name: 项目群聊
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
about: QQ 群629454374

View File

@ -1,6 +1,10 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
@ -39,7 +43,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
pengzhile/new-api
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
pengzhile/new-api
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

1
.gitignore vendored
View File

@ -6,4 +6,3 @@ upload
build
*.db-journal
logs
web/dist

View File

@ -1,13 +1,13 @@
FROM oven/bun:latest AS builder
FROM node:16 as builder
WORKDIR /build
COPY web/package.json .
RUN bun install
RUN npm install
COPY ./web .
COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang:1.21 AS builder2
FROM golang AS builder2
ENV GO111MODULE=on \
CGO_ENABLED=1 \
@ -17,7 +17,7 @@ WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . .
COPY --from=builder /build/dist ./web/dist
COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine
@ -25,7 +25,7 @@ FROM alpine
RUN apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates
&& update-ca-certificates 2>/dev/null || true
COPY --from=builder2 /build/one-api /
EXPOSE 3000

214
LICENSE
View File

@ -1,201 +1,21 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
MIT License
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Copyright (c) 2024 Calcium-Ion
1. Definitions.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,14 +0,0 @@
FRONTEND_DIR = ./web
BACKEND_DIR = .
.PHONY: all build-frontend start-backend
all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
start-backend:
@echo "Starting backend dev server..."
@cd $(BACKEND_DIR) && go run main.go &

View File

@ -2,81 +2,290 @@
**简介**:Midjourney Proxy API文档
## 接口列表
支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
## 模型列表
### midjourney-proxy支持
- mj_imagine (绘图)
- mj_variation (变换)
- mj_reroll (重绘)
- mj_blend (混合)
- mj_upscale (放大)
- mj_describe (图生文)
### 仅midjourney-proxy-plus支持
- mj_zoom (比例变焦)
- mj_shorten (提示词缩短)
- mj_modal (窗口提交局部重绘和自定义比例变焦必须和mj_modal一同添加)
- mj_inpaint (局部重绘提交必须和mj_modal一同添加)
- mj_custom_zoom (自定义比例变焦必须和mj_modal一同添加)
- mj_high_variation (强变换)
- mj_low_variation (弱变换)
- mj_pan (平移)
- swap_face (换脸)
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
```json
{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05
"mj_upscale": 0.05
}
```
其中mj_inpaint和mj_custom_zoom的价格设置为0是因为这两个模型需要搭配mj_modal使用所以价格由mj_modal决定。
## 渠道设置
### 对接 midjourney-proxy(plus)
1.
部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus**
,模型请参考上方模型列表
3. **代理**填写midjourney-proxy部署的地址例如http://localhost:8080
### 对接 midjourney-proxy
1. 部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
2. **代理**填写上游new api的地址例如http://localhost:3000
1. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
2. 地址填写上游new api的地址例如http://localhost:8080
3. 密钥填写上游new api的密钥
## 任务提交
### 绘图变化
**接口地址**:`/mj/submit/change`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"action"
:
"UPSCALE",
"index"
:
1,
"notifyHook"
:
"",
"state"
:
"",
"taskId"
:
"1320098173412546"
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|   state | 自定义参数 | | false | string | |
|   taskId | 任务ID | | true | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
### 提交Imagine任务
**接口地址**:`/mj/submit/imagine`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"base64"
:
"",
"notifyHook"
:
"",
"prompt"
:
"Cat",
"state"
:
""
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-------------------------|------|-------|-------------|-------------|
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
|   base64 | 垫图base64 | | false | string | |
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|   prompt | 提示词 | | true | string | |
|   state | 自定义参数 | | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
## 任务查询
### 指定ID获取任务
**接口地址**:`/mj/task/{id}/fetch`
**请求方式**:`GET`
**请求数据类型**:`application/x-www-form-urlencoded`
**响应数据类型**:`*/*`
**接口描述**:
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------|------|------|-------|--------|--------|
| id | 任务ID | path | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 任务 |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|----------------------------------------------------------|----------------|----------------|
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
| description | 任务描述 | string | |
| failReason | 失败原因 | string | |
| finishTime | 结束时间 | integer(int64) | integer(int64) |
| id | 任务ID | string | |
| imageUrl | 图片url | string | |
| progress | 任务进度 | string | |
| prompt | 提示词 | string | |
| promptEn | 提示词-英文 | string | |
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
| state | 自定义参数 | string | |
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
| submitTime | 提交时间 | integer(int64) | integer(int64) |
**响应示例**:
```javascript
{
"action"
:
"",
"description"
:
"",
"failReason"
:
"",
"finishTime"
:
0,
"id"
:
"",
"imageUrl"
:
"",
"progress"
:
"",
"prompt"
:
"",
"promptEn"
:
"",
"startTime"
:
0,
"state"
:
"",
"status"
:
"",
"submitTime"
:
0
}
```

130
README.md
View File

@ -1,84 +1,45 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
> [!IMPORTANT]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
## 主要变更
此分叉版本的主要变更如下:
> [!NOTE]
> 最新版Docker镜像 calciumion/new-api:latest
> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
## 此分叉版本的主要变更
1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md)
2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销
5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量
7. 兼容原版One API的数据库可直接使用原版数据库one-api.db
8. 支持模型按次数收费,可在 系统设置-运营设置 中设置
9. 支持渠道**加权随机**
10. 数据看板
11. 可设置令牌能调用的模型
12. 支持Telegram授权登录。
1. 系统设置-配置登录注册-允许通过Telegram登录
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
7. 支持 gpt-4-1106-vision-previewdall-e-3tts-1
8. 支持第三方模型 **gps** gpt-4-gizmo-*在渠道中添加自定义模型gpt-4-gizmo-*即可
9. 兼容原版One API的数据库可直接使用原版数据库one-api.db
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
11. 支持gemini-progemini-pro-vision模型
12. 支持渠道**加权随机**
13. 数据看板
14. 可设置令牌能调用的模型
## 模型支持
此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*, g-*
2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
6. [零一万物](https://platform.lingyiwanwu.com/)
7. 自定义渠道,支持填入完整调用地址
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
@ -97,47 +58,26 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
在`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## Midjourney接口设置文档
[对接文档](Midjourney.md)
## Suno接口设置文档
[对接文档](Suno.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## Star History

View File

@ -1,62 +0,0 @@
# Rerank API文档
**简介**:Rerank API文档
## 接入Dify
模型供应商选择Jina按要求填写模型信息即可接入Dify。
## 请求方式
Post: /v1/rerank
Request:
```json
{
"model": "rerank-multilingual-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
```
Response:
```json
{
"results": [
{
"document": {
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
},
"index": 2,
"relevance_score": 0.9999702
},
{
"document": {
"text": "Carson City is the capital city of the American state of Nevada."
},
"index": 0,
"relevance_score": 0.67800725
},
{
"document": {
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
},
"index": 3,
"relevance_score": 0.02800752
}
],
"usage": {
"prompt_tokens": 158,
"completion_tokens": 0,
"total_tokens": 158
}
}
```

44
Suno.md
View File

@ -1,44 +0,0 @@
# Suno API文档
**简介**:Suno API文档
## 接口列表
支持的接口如下:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id
## 模型列表
### Suno API支持
- suno_music (自定义模式、灵感模式、续写)
- suno_lyrics (生成歌词)
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
```json
{
"suno_music": 0.3,
"suno_lyrics": 0.01
}
```
## 渠道设置
### 对接 Suno API
1.
部署 Suno API并配置好suno账号等强烈建议设置密钥[项目地址](https://github.com/Suno-API/Suno-API)
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
,模型请参考上方模型列表
3. **代理**填写 Suno API 部署的地址例如http://localhost:8080
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
2. **代理**填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥

View File

@ -9,20 +9,14 @@ import (
"github.com/google/uuid"
)
// Pay Settings
var StripeApiSecret = ""
var StripeWebhookSecret = ""
var StripePriceId = ""
var PaymentEnabled = false
var StripeUnitPrice = 8.0
var MinTopUp = 5
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var OutProxyUrl = ""
var PayAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var Footer = ""
var Logo = ""
var TopUpLink = ""
@ -32,11 +26,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var DrawingEnabled = true
var TaskEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute
var DataExportDefaultTime = "hour" // unit: minute
var DefaultCollapseSidebar = false // default value of collapse sidebar
// Any options with "Secret", "Token" in its key won't be return by GetOptions
@ -52,15 +44,11 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
@ -80,7 +68,6 @@ var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPSSLEnabled = false
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
@ -88,10 +75,6 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@ -99,9 +82,6 @@ var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var TelegramBotToken = ""
var TelegramBotName = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
@ -120,17 +100,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
@ -143,10 +120,6 @@ const (
RoleRootUser = 100
)
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
@ -157,10 +130,10 @@ var (
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@ -200,19 +173,13 @@ const (
ChannelStatusAutoDisabled = 3
)
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAPI2D = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeCloseAI = 4
ChannelTypeOpenAISB = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
@ -232,65 +199,32 @@ const (
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"", //24
}

View File

@ -2,6 +2,5 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db?_busy_timeout=5000"

View File

@ -1,40 +0,0 @@
package common
import (
"errors"
"net/smtp"
"strings"
)
type outlookAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &outlookAuth{username, password}
}
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown fromServer")
}
}
return nil, nil
}
func isOutlookServer(server string) bool {
// 兼容多地区的outlook邮箱和ofb邮箱
// 其实应该加一个Option来区分是否用LOGIN的方式登录
// 先临时兼容一下
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}

View File

@ -9,31 +9,22 @@ import (
"time"
)
func generateMessageID() string {
domain := strings.Split(SMTPAccount, "@")[1]
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
}
func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount
}
if SMTPServer == "" && SMTPAccount == "" {
return fmt.Errorf("SMTP 服务器未配置")
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Date: %s\r\n"+
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), generateMessageID(), content))
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")
var err error
if SMTPPort == 465 || SMTPSSLEnabled {
if SMTPPort == 465 {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: SMTPServer,
@ -71,9 +62,6 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil {
return err
}
} else if isOutlookServer(SMTPAccount) {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
}

View File

@ -1,38 +0,0 @@
package common
import (
"fmt"
"os"
"strconv"
)
func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
b, err := strconv.ParseBool(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
return defaultValue
}
return b
}

View File

@ -5,37 +5,18 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"strings"
)
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, err
}
_ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c)
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return err
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
err = c.Request.Body.Close()
if err != nil {
return err
}
err = json.Unmarshal(requestBody, &v)
if err != nil {
return err
}

View File

@ -3,7 +3,6 @@ package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
@ -17,7 +16,7 @@ func SafeGoroutine(f func()) {
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
@ -31,36 +30,3 @@ func SafeSendBool(ch chan bool, value bool) (closed bool) {
// If the code reaches here, then the channel was not closed.
return false
}
func SafeSendString(ch chan string, value string) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}
// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()
// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}

View File

@ -1,8 +1,6 @@
package common
import (
"encoding/json"
)
import "encoding/json"
var GroupRatio = map[string]float64{
"default": 1,

View File

@ -1,84 +0,0 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@ -5,13 +5,14 @@ import (
"encoding/base64"
"errors"
"fmt"
"golang.org/x/image/webp"
"github.com/chai2010/webp"
"image"
"io"
"net/http"
"strings"
)
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
@ -21,17 +22,17 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, "", "", err
return image.Config{}, "", err
}
// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, format, err := getImageConfig(reader)
return config, format, base64String, err
return config, format, err
}
func IsImageUrl(url string) (bool, error) {
resp, err := ProxiedHttpHead(url, OutProxyUrl)
resp, err := http.Head(url)
if err != nil {
return false, err
}
@ -41,19 +42,15 @@ func IsImageUrl(url string) (bool, error) {
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := ProxiedHttpGet(url, OutProxyUrl)
resp, err := http.Get(url)
if err != nil {
return
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
@ -66,18 +63,13 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := ProxiedHttpGet(imageUrl, OutProxyUrl)
response, err := http.Get(imageUrl)
if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
}
defer response.Body.Close()
if response.StatusCode != 200 {
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
return image.Config{}, "", err
}
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

View File

@ -98,11 +98,3 @@ func LogQuota(quota int) string {
return fmt.Sprintf("%d 点额度", quota)
}
}
func LogQuotaF(quota float64) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", int64(quota))
}
}

View File

@ -3,241 +3,100 @@ package common
import (
"encoding/json"
"strings"
"sync"
"time"
)
// from songquanpeng/one-api
const (
USD2RMB = 7.3 // 暂定 1 USD = 7.3 RMB
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// modelRatio
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var defaultModelRatio = map[string]float64{
var ModelRatio = map[string]float64{
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"g-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075,
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4o": 1.25, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"gpt-4o-2024-11-20": 1.25, // $0.01 / 1K tokens
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 0.55, // $0.0011 / 1K tokens
"o1-mini-2024-09-12": 0.55,
"o3-mini": 0.55,
"o3-mini-2025-01-31": 0.55,
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-7-sonnet-20250219": 1.5,
"claude-3-7-sonnet-20250219-thinking": 1.5,
"claude-3-5-haiku-20241022": 0.4,
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20241022": 1.5, // $3 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus": 1.5,
"command-r-08-2024": 0.075,
"command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
"llama-3-sonar-large-32k-chat": 1 / 1000 * USD,
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
var defaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-2": 0.02,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
"g-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_modal": 0.1,
"mj_zoom": 0.1,
"mj_shorten": 0.1,
"mj_high_variation": 0.1,
"mj_low_variation": 0.1,
"mj_pan": 0.1,
"mj_inpaint": 0,
"mj_custom_zoom": 0,
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"g-*": 2,
"gpt-4-all": 2,
"gpt-4o-all": 2,
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
return modelPriceMap
var ModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
}
func ModelPrice2JSONString() string {
GetModelPriceMap()
jsonBytes, err := json.Marshal(modelPriceMap)
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
@ -245,42 +104,26 @@ func ModelPrice2JSONString() string {
}
func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
ModelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
}
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap()
func GetModelPrice(name string, printErr bool) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
price, ok := modelPriceMap[name]
price, ok := ModelPrice[name]
if !ok {
if printErr {
SysError("model price not found: " + name)
}
return -1, false
return -1
}
return price, true
}
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelRatioMap
return price
}
func ModelRatio2JSONString() string {
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
@ -288,20 +131,15 @@ func ModelRatio2JSONString() string {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
ModelRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}
func GetModelRatio(name string) float64 {
GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
ratio, ok := modelRatioMap[name]
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
return 30
@ -309,40 +147,7 @@ func GetModelRatio(name string) float64 {
return ratio
}
func DefaultModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(defaultModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") {
return 3
@ -350,88 +155,28 @@ func GetCompletionRatio(name string) float64 {
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" {
return 3
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 4.0 / 3.0
return 1.333333
}
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || "gpt-4o-2024-05-13" == name {
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
return 4
}
return 2
}
if "o1" == name || strings.HasPrefix(name, "o1-") {
return 4
}
if "o3" == name || strings.HasPrefix(name, "o3-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3
} else if strings.HasPrefix(name, "claude-2") {
return 3
} else if strings.HasPrefix(name, "claude-3") {
return 5
return 3.38
}
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 4
}
if strings.HasPrefix(name, "command") {
switch name {
case "command-r":
return 3
case "command-r-plus":
return 5
case "command-r-08-2024":
return 4
case "command-r-plus-08-2024":
return 4
default:
return 2
}
}
if strings.HasPrefix(name, "deepseek") {
return 2
}
if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Lite-") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Character") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Functions") {
return 2
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.64
case "llama3-8b-8192":
return 2
case "llama3-70b-8192":
return 0.79 / 0.59
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
if strings.HasPrefix(name, "claude-2") {
return 2.965517
}
return 1
}
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@ -18,8 +18,9 @@ func InitRedisClient() (err error) {
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
SysLog("SYNC_FREQUENCY not set, use default value 60")
SyncFrequency = 60
RedisEnabled = false
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
@ -72,35 +73,6 @@ func RedisDel(key string) error {
}
func RedisDecrease(key string, value int64) error {
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil {
// 失败则尝试直接减少
return RDB.DecrBy(context.Background(), key, value).Err()
}
// 如果剩余生存时间大于0则进行减少操作
if ttl > 0 {
ctx := context.Background()
// 开始一个Redis事务
txn := RDB.TxPipeline()
// 减少余额
decrCmd := txn.DecrBy(ctx, key, value)
if err := decrCmd.Err(); err != nil {
return err // 如果减少失败,则直接返回错误
}
// 重新设置过期时间,使用原来的过期时间
txn.Expire(ctx, key, ttl)
// 执行事务
_, err = txn.Exec(ctx)
return err
} else {
_ = RedisDel(key)
}
return nil
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
}

View File

@ -1,70 +0,0 @@
package common
import (
"encoding/json"
"math/rand"
"strconv"
"unsafe"
)
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
}
return str
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func MapToJsonStr(m map[string]interface{}) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return nil
}
return m
}
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}

View File

@ -1,8 +1,6 @@
package common
import (
"encoding/json"
)
import "encoding/json"
var TopupGroupRatio = map[string]float64{
"default": 1,

View File

@ -1,46 +0,0 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@ -1,25 +1,19 @@
package common
import (
"context"
"errors"
crand "crypto/rand"
"encoding/base64"
"fmt"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template"
"log"
"math/big"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
"unsafe"
)
func OpenBrowser(url string) {
@ -136,11 +130,6 @@ func IntMax(a int, b int) int {
}
}
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
@ -150,35 +139,33 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.New(rand.NewSource(time.Now().UnixNano()))
rand.Seed(time.Now().UnixNano())
}
func GenerateRandomCharsKey(length int) (string, error) {
b := make([]byte, length)
maxI := big.NewInt(int64(len(keyChars)))
for i := range b {
n, err := crand.Int(crand.Reader, maxI)
if err != nil {
return "", err
}
b[i] = keyChars[n.Int64()]
}
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
func GenerateKey() string {
//rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48)
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomInt(max int) int {
@ -203,64 +190,49 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
if "" == proxyUrl {
return &http.Client{}, nil
}
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
if strings.HasPrefix(proxyUrl, "http") {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(u),
},
}, nil
} else if strings.HasPrefix(proxyUrl, "socks") {
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
return nil, err
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
},
}, nil
}
return nil, errors.New("unsupported proxy type")
return false
}
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Get(url)
}
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Head(url)
// []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}

View File

@ -1,35 +0,0 @@
package constant
import (
"encoding/json"
"one-api/common"
)
var Chats = []map[string]string{
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},
{
"OpenCat": "opencat://team/join?domain={address}&token={key}",
},
}
func UpdateChatsByJsonString(jsonString string) error {
Chats = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &Chats)
}
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
common.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
}

View File

@ -1,51 +0,0 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@ -1,50 +0,0 @@
package constant
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const (
MjErrorUnknown = 5
MjRequestError = 4
)
const (
MjActionImagine = "IMAGINE"
MjActionDescribe = "DESCRIBE"
MjActionBlend = "BLEND"
MjActionUpscale = "UPSCALE"
MjActionVariation = "VARIATION"
MjActionReRoll = "REROLL"
MjActionInPaint = "INPAINT"
MjActionModal = "MODAL"
MjActionZoom = "ZOOM"
MjActionCustomZoom = "CUSTOM_ZOOM"
MjActionShorten = "SHORTEN"
MjActionHighVariation = "HIGH_VARIATION"
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
)
var MidjourneyModel2Action = map[string]string{
"mj_imagine": MjActionImagine,
"mj_describe": MjActionDescribe,
"mj_blend": MjActionBlend,
"mj_upscale": MjActionUpscale,
"mj_variation": MjActionVariation,
"mj_reroll": MjActionReRoll,
"mj_modal": MjActionModal,
"mj_inpaint": MjActionInPaint,
"mj_zoom": MjActionZoom,
"mj_custom_zoom": MjActionCustomZoom,
"mj_shorten": MjActionShorten,
"mj_high_variation": MjActionHighVariation,
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
}

View File

@ -1,8 +0,0 @@
package constant
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1

View File

@ -1,43 +0,0 @@
package constant
import "strings"
var CheckSensitiveEnabled = true
var CheckSensitiveOnPromptEnabled = true
//var CheckSensitiveOnCompletionEnabled = true
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
var StopOnSensitiveEnabled = true
// StreamCacheQueueLength 流模式缓存队列长度0表示无缓存
var StreamCacheQueueLength = 0
// SensitiveWords 敏感词
// var SensitiveWords []string
var SensitiveWords = []string{
"test_sensitive",
}
func SensitiveWordsToString() string {
return strings.Join(SensitiveWords, "\n")
}
func SensitiveWordsFromString(s string) {
SensitiveWords = []string{}
sw := strings.Split(s, "\n")
for _, w := range sw {
w = strings.TrimSpace(w)
if w != "" {
SensitiveWords = append(SensitiveWords, w)
}
}
}
func ShouldCheckPromptSensitive() bool {
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
}
//func ShouldCheckCompletionSensitive() bool {
// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
//}

View File

@ -1,18 +0,0 @@
package constant
type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
)
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
)
var SunoModel2Action = map[string]string{
"suno_music": SunoActionMusic,
"suno_lyrics": SunoActionLyrics,
}

View File

@ -3,7 +3,6 @@ package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/dto"
"one-api/model"
)
@ -28,7 +27,7 @@ func GetSubscription(c *gin.Context) {
expiredTime = 0
}
if err != nil {
openAIError := dto.OpenAIError{
openAIError := OpenAIError{
Message: err.Error(),
Type: "upstream_error",
}
@ -70,7 +69,7 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId)
}
if err != nil {
openAIError := dto.OpenAIError{
openAIError := OpenAIError{
Message: err.Error(),
Type: "new_api_error",
}

View File

@ -8,7 +8,6 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"time"
@ -93,7 +92,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
res, err := service.GetHttpClient().Do(req)
res, err := httpClient.Do(req)
if err != nil {
return nil, err
}
@ -214,8 +213,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:
return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
@ -309,7 +310,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
service.DisableChannel(channel.Id, channel.Name, "余额不足")
disableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)

View File

@ -5,164 +5,107 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
}
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
switch channel.Type {
case common.ChannelTypePaLM:
fallthrough
case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
fallthrough
case common.ChannelTypeAli:
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
if request.Model == "" {
request.Model = "gpt-35-turbo"
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
defer func() {
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
}
}()
default:
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
}
}
baseUrl := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseUrl = channel.GetBaseURL()
}
requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, testModel)
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
if channel.Type == common.ChannelTypeAzure {
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
}
request := buildTestRequest(testModel)
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
}
jsonData, err := json.Marshal(convertedRequest)
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err, nil
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
modelPrice, usePrice := common.GetModelPrice(testModel, false)
modelRatio := common.GetModelRatio(testModel)
completionRatio := common.GetCompletionRatio(testModel)
ratio := modelRatio
quota := 0
if !usePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
quota = int(modelPrice * common.QuotaPerUnit)
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0"
}
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later
Stream: false,
}
if "o1" == model || strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
func buildTestRequest() *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
}
content, _ := json.Marshal("hi")
testMessage := dto.Message{
testMessage := Message{
Role: "user",
Content: content,
}
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id"))
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -170,7 +113,8 @@ func TestChannel(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(channelId, true)
testModel := c.Param("model")
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -178,9 +122,12 @@ func TestChannel(c *gin.Context) {
})
return
}
testModel := c.Query("model")
testRequest := buildTestRequest()
if testModel != "" {
testRequest.Model = testModel
}
tik := time.Now()
err, _ = testChannel(channel, testModel)
err, _ = testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@ -204,6 +151,31 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, common.RootUserEmail, content)
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
func testAllChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
@ -219,42 +191,38 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
gopool.Go(func() {
go func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiWithStatusErr := testChannel(channel, "")
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
}
ban := false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
shouldBanChannel = true
ban = true
}
// disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error())
if openaiErr != nil {
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
ban = true
}
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
disableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
@ -267,7 +235,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
})
}()
return nil
}

View File

@ -1,8 +1,6 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
@ -11,34 +9,6 @@ import (
"strings"
)
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
@ -65,65 +35,6 @@ func GetAllChannels(c *gin.Context) {
return
}
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
})
}
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ids,
})
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {
@ -143,9 +54,8 @@ func FixChannelsAbilities(c *gin.Context) {
func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
//idSort, _ := strconv.ParseBool(c.Query("id_sort"))
channels, err := model.SearchChannels(keyword, group, modelKeyword)
channels, err := model.SearchChannels(keyword, group)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -198,28 +108,6 @@ func AddChannel(c *gin.Context) {
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
@ -319,27 +207,6 @@ func UpdateChannel(c *gin.Context) {
})
return
}
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@ -112,9 +112,7 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{
GitHubId: githubUser.Login,
}
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@ -123,18 +121,8 @@ func GitHubOAuth(c *gin.Context) {
})
return
}
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@ -145,7 +133,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@ -4,7 +4,6 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
)
func GetGroups(c *gin.Context) {
@ -18,22 +17,3 @@ func GetGroups(c *gin.Context) {
"data": groupNames,
})
}
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

View File

@ -1,247 +0,0 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@ -1,22 +1,17 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
if p < 0 {
p = 0
}
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
@ -25,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -36,26 +31,15 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
return
}
func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
if pageSize > 100 {
pageSize = 100
if p < 0 {
p = 0
}
userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type"))
@ -63,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -74,12 +58,7 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
return
}
@ -192,7 +171,7 @@ func DeleteHistoryLogs(c *gin.Context) {
})
return
}
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@ -10,14 +10,143 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/service"
"strconv"
"strings"
"time"
)
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
continue
}
responseBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
log.Printf("responseBody: %s", string(responseBody))
var responseItem Midjourney
// err = json.NewDecoder(resp.Body).Decode(&responseItem)
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
if strings.Contains(err.Error(), "cannot unmarshal number into Go struct field Midjourney.status of type string") {
var responseWithoutStatus MidjourneyWithoutStatus
var responseStatus MidjourneyStatus
err1 := json.Unmarshal(responseBody, &responseWithoutStatus)
err2 := json.Unmarshal(responseBody, &responseStatus)
if err1 == nil && err2 == nil {
jsonData, err3 := json.Marshal(responseWithoutStatus)
if err3 != nil {
log.Printf("UpdateMidjourneyTask error1: %v", err3)
continue
}
err4 := json.Unmarshal(jsonData, &responseStatus)
if err4 != nil {
log.Printf("UpdateMidjourneyTask error2: %v", err4)
continue
}
responseItem.Status = strconv.Itoa(responseStatus.Status)
} else {
log.Printf("UpdateMidjourneyTask error3: %v", err)
continue
}
} else {
log.Printf("UpdateMidjourneyTask error4: %v", err)
continue
}
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error5: %v", err)
}
log.Printf("UpdateMidjourneyTask success: %v", task)
cancel()
}
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//imageModel := "midjourney"
ctx := context.TODO()
@ -86,27 +215,23 @@ func UpdateMidjourneyTaskBulk() {
continue
}
// 设置超时时间
timeout := time.Second * 15
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []dto.MidjourneyDto
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
@ -118,16 +243,10 @@ func UpdateMidjourneyTaskBulk() {
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
// 如果时间超过一小时且进度不是100%,则认为任务失败
if useTime > 3600000 && task.Progress != "100%" {
responseItem.FailReason = "上游任务超时超过1小时"
responseItem.Status = "FAILURE"
}
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
@ -138,41 +257,34 @@ func UpdateMidjourneyTaskBulk() {
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if responseItem.Properties != nil {
propertiesStr, _ := json.Marshal(responseItem.Properties)
task.Properties = string(propertiesStr)
}
if responseItem.Buttons != nil {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
if task.Quota != 0 {
shouldReturnQuota = true
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
}
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
@ -231,12 +343,6 @@ func GetAllMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
@ -263,7 +369,7 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney

View File

@ -5,51 +5,28 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
func TestStatus(c *gin.Context) {
err := model.PingDB()
if err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"success": false,
"message": "数据库连接失败",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Server is running",
})
return
}
func GetStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"stripe_unit_price": common.StripeUnitPrice,
"min_topup": common.MinTopUp,
"price": common.Price,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
@ -59,13 +36,8 @@ func GetStatus(c *gin.Context) {
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
},
})
return
@ -124,20 +96,10 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
}
@ -145,22 +107,11 @@ func SendEmailVerification(c *gin.Context) {
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
})
return
}
}
if common.EmailAliasRestrictionEnabled {
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
if containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
})
return
}
}
if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@ -2,30 +2,43 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
var openAIModels []dto.OpenAIModels
var openAIModelsMap map[string]dto.OpenAIModels
var channelId2Models map[int][]string
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
func getPermission() []dto.OpenAIModelPermission {
var permission []dto.OpenAIModelPermission
permission = append(permission, dto.OpenAIModelPermission{
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
@ -39,189 +52,595 @@ func getPermission() []dto.OpenAIModelPermission {
Group: nil,
IsBlocking: false,
})
return permission
}
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
permission := getPermission()
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
openAIModels = []OpenAIModels{
{
Id: "midjourney",
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
Created: 1677649963,
OwnedBy: "Midjourney",
Permission: permission,
Root: modelName,
Root: "midjourney",
Parent: nil,
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
},
{
Id: "dall-e-2",
Object: "model",
Created: 1626777600,
OwnedBy: moonshot.ChannelName,
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: modelName,
Root: "dall-e-2",
Parent: nil,
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: modelName,
Root: "dall-e-3",
Parent: nil,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
},
{
Id: "whisper-1",
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: modelName,
Root: "whisper-1",
Parent: nil,
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
},
{
Id: "tts-1",
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: modelName,
Root: "tts-1",
Parent: nil,
})
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0125",
Object: "model",
Created: 1706232090,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0125",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-0125-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0125-preview",
Parent: nil,
},
{
Id: "gpt-4-turbo-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-turbo-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "gpt-4-1106-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-3-small",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-embedding-3-large",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "babbage-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "babbage-002",
Parent: nil,
},
{
Id: "davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "davinci-002",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anthropic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{
Id: "Embedding-V1",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "Embedding-V1",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
{
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus",
Parent: nil,
},
{
Id: "text-embedding-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "text-embedding-v1",
Parent: nil,
},
{
Id: "SparkDesk",
Object: "model",
Created: 1677649963,
OwnedBy: "xunfei",
Permission: permission,
Root: "SparkDesk",
Parent: nil,
},
{
Id: "360GPT_S2_V9",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "360GPT_S2_V9",
Parent: nil,
},
{
Id: "embedding-bert-512-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding-bert-512-v1",
Parent: nil,
},
{
Id: "embedding_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "embedding_s1_v1",
Parent: nil,
},
{
Id: "semantic_similarity_s1_v1",
Object: "model",
Created: 1677649963,
OwnedBy: "360",
Permission: permission,
Root: "semantic_similarity_s1_v1",
Parent: nil,
},
{
Id: "hunyuan",
Object: "model",
Created: 1677649963,
OwnedBy: "tencent",
Permission: permission,
Root: "hunyuan",
Parent: nil,
},
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
for _, aiModel := range openAIModels {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType, success := relayconstant.ChannelType2APIType(i)
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
continue
}
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
}
func ListModels(c *gin.Context) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission()
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
for allowModel, _ := range tokenModelLimit {
if _, ok := openAIModelsMap[allowModel]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
}
}
c.JSON(200, gin.H{
"success": true,
"data": userOpenAiModels,
})
}
func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": openAIModels,
})
}
func DashboardListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": channelId2Models,
"object": "list",
"data": openAIModels,
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if aiModel, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, aiModel)
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
} else {
openAIError := dto.OpenAIError{
openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error",
Param: "model",

View File

@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) {
var options []*model.Option
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
@ -50,14 +50,6 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{

View File

@ -1,40 +0,0 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
)
func GetPricing(c *gin.Context) {
pricing := model.GetPricing()
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": common.GroupRatio,
})
}
func ResetModelRatio(c *gin.Context) {
defaultStr := common.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = common.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "重置模型倍率成功",
})
}

220
controller/relay-aiproxy.go Normal file
View File

@ -0,0 +1,220 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
type AIProxyLibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type AIProxyLibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type AIProxyLibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type AIProxyLibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []AIProxyLibraryDocument `json:"documents"`
AIProxyLibraryError
}
type AIProxyLibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []AIProxyLibraryDocument `json:"documents"`
}
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = string(request.Messages[len(request.Messages)-1].Content)
}
return &AIProxyLibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: common.GetUUID(),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &stopFinishReason
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &ChatCompletionsStreamResponse{
Id: common.GetUUID(),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: response.Model,
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
}
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
var documents []AIProxyLibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var AIProxyLibraryResponse AIProxyLibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

330
controller/relay-ali.go Normal file
View File

@ -0,0 +1,330 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: string(message.Content),
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = string(message.Content)
break
}
messages = append(messages, AliMessage{
User: string(message.Content),
Bot: string(request.Messages[i+1].Content),
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text)
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

225
controller/relay-audio.go Normal file
View File

@ -0,0 +1,225 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
)
var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
var audioRequest AudioRequest
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} else {
audioRequest = AudioRequest{
Model: "whisper-1",
}
}
//err := common.UnmarshalBodyReusable(c, &audioRequest)
// request validation
if audioRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
}
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp)
}
var audioResponse AudioResponse
defer func(ctx context.Context) {
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
quota := 0
var promptTokens = 0
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = countAudioToken(audioRequest.Input, audioRequest.Model)
promptTokens = quota
} else {
quota = countAudioToken(audioResponse.Text, audioRequest.Model)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
} else {
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@ -1,4 +1,4 @@
package baidu
package controller
import (
"bufio"
@ -9,9 +9,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"sync"
"time"
@ -19,89 +16,149 @@ import (
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
baiduRequest.System = message.StringContent()
messages = append(messages, BaiduMessage{
Role: "user",
Content: string(message.Content),
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
Content: string(message.Content),
})
}
}
return &baiduRequest
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
content, _ := json.Marshal(response.Result)
choice := dto.OpenAITextResponseChoice{
choice := OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := dto.OpenAITextResponse{
fullTextResponse := OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []dto.OpenAITextResponseChoice{choice},
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(baiduResponse.Result)
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &relaycommon.StopFinishReason
choice.FinishReason = &stopFinishReason
}
response := dto.ChatCompletionsStreamResponse{
response := ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
@ -110,8 +167,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@ -138,7 +195,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@ -168,28 +225,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@ -201,7 +258,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@ -209,23 +266,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
@ -237,7 +294,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@ -280,7 +337,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := service.GetImpatientHttpClient().Do(req)
res, err := impatientHTTPClient.Do(req)
if err != nil {
return nil, err
}

221
controller/relay-claude.go Normal file
View File

@ -0,0 +1,221 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "max_tokens":
return "length"
default:
return reason
}
}
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
}
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
response.Id = responseId
response.Created = createdTime
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if claudeResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
completionTokens := countTokenText(claudeResponse.Completion, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

336
controller/relay-gemini.go Normal file
View File

@ -0,0 +1,336 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: string(message.Content),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: stopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

230
controller/relay-image.go Normal file
View File

@ -0,0 +1,230 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
startTime := time.Now()
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.Contains(imageRequest.Size, "×") {
return errorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
if imageRequest.N != 1 {
return errorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
}
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
sizeRatio = 2.5
}
qualityRatio := 1.0
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
qualityRatio = 2.0
if imageRequest.Size == "1024×1792" || imageRequest.Size == "1792×1024" {
qualityRatio = 1.5
}
}
quota := int(ratio*sizeRatio*qualityRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relayErrorHandler(resp)
}
var textResponse ImageResponse
defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
if consumeQuota {
if resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

650
controller/relay-mj.go Normal file
View File

@ -0,0 +1,650 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type Midjourney struct {
MjId string `json:"id"`
Action string `json:"action"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
var DefaultModelPrice = map[string]float64{
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByOnlyMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
})
return
}
resp, err := http.Get(midjourneyTask.ImageUrl)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
})
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
c.JSON(resp.StatusCode, gin.H{
"error": string(responseBody),
})
return
}
// 从Content-Type头获取MIME类型
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// 如果无法确定内容类型则默认为jpeg
contentType = "image/jpeg"
}
// 设置响应的内容类型
c.Writer.Header().Set("Content-Type", contentType)
// 将图片流式传输到响应体
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println("Failed to stream image:", err)
}
return
}
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
return
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
userId := c.GetInt("id")
var err error
var respBody []byte
switch relayMode {
case RelayModeMidjourneyTaskFetch:
taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
midjourneyTask := getMidjourneyTaskModel(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
case RelayModeMidjourneyTaskFetchByCondition:
var condition = struct {
IDs []string `json:"ids"`
}{}
err = c.BindJSON(&condition)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
var tasks []Midjourney
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := getMidjourneyTaskModel(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]Midjourney, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
}
c.Writer.Header().Set("Content-Type", "application/json")
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
const (
// type 1 根据 mode 价格不同
MJSubmitActionImagine = "IMAGINE"
MJSubmitActionVariation = "VARIATION" //变换
MJSubmitActionBlend = "BLEND" //混图
MJSubmitActionReroll = "REROLL" //重新生成
// type 2 固定价格
MJSubmitActionDescribe = "DESCRIBE"
MJSubmitActionUpscale = "UPSCALE" // 放大
)
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
var midjRequest MidjourneyRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
}
}
}
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Code: 4,
Description: "prompt_is_required",
}
}
midjRequest.Action = "IMAGINE"
} else if relayMode == RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = "DESCRIBE"
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return &MidjourneyResponse{
Code: 4,
Description: "content_is_required",
}
}
params := convertSimpleChangeParams(midjRequest.Content)
if params == nil {
return &MidjourneyResponse{
Code: 4,
Description: "content_parse_failed",
}
}
mjId = params.ID
midjRequest.Action = params.Action
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
} else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &MidjourneyResponse{
Code: 4,
Description: "upscale_task_can_not_be_change",
}
} else if originTask.Status != "SUCCESS" {
return &MidjourneyResponse{
Code: 4,
Description: "task_status_is_not_success",
}
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
channel, err := model.GetChannelById(originTask.ChannelId, false)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "channel_not_found",
}
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
log.Printf("检测到此操作为放大、变换获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_model_mapping_failed",
}
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
log.Printf("fullRequestURL: %s", fullRequestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "marshal_text_request_failed",
}
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
modelPrice := common.GetModelPrice(mjAction, true)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
defaultPrice, ok := DefaultModelPrice[mjAction]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "create_request_failed",
}
}
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//mjToken := ""
//if c.Request.Header.Get("Authorization") != "" {
// mjToken = strings.Split(c.Request.Header.Get("Authorization"), " ")[1]
//}
//req.Header.Set("Authorization", "Bearer midjourney-proxy")
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt)
resp, err := httpClient.Do(req)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
err = req.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
var midjResponse MidjourneyResponse
defer func(ctx context.Context) {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}(c.Request.Context())
//if consumeQuota {
//
//}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "read_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 {
return &MidjourneyResponse{
Code: 4,
Description: "fail_to_fetch_midjourney status_code: " + strconv.Itoa(resp.StatusCode),
}
}
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
// 文档https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
//1-提交成功
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
// 22-排队中 {"code":22,"description":"排队中前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
midjourneyTask.FailReason = midjResponse.Description
consumeQuota = false
}
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
// 将 properties 转换为一个 map
properties, ok := midjResponse.Properties.(map[string]interface{})
if ok {
imageUrl, ok1 := properties["imageUrl"].(string)
status, ok2 := properties["status"].(string)
if ok1 && ok2 {
midjourneyTask.ImageUrl = imageUrl
midjourneyTask.Status = status
if status == "SUCCESS" {
midjourneyTask.Progress = "100%"
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
midjResponse.Code = 1
}
}
}
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
responseBody = []byte(newBody)
}
err = midjourneyTask.Insert()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
//修改返回值
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}
type taskChangeParams struct {
ID string
Action string
Index int
}
func convertSimpleChangeParams(content string) *taskChangeParams {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &taskChangeParams{}
changeParams.ID = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}

162
controller/relay-openai.go Normal file
View File

@ -0,0 +1,162 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
"sync"
"time"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
}
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case RelayModeChatCompletions:
var streamResponses []ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
case RelayModeCompletions:
var streamResponses []CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSend(stopChan, true)
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
wg.Wait()
return nil, responseTextBuilder.String()
}
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(string(choice.Message.Content), model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
return nil, &textResponse.Usage
}

View File

@ -1,4 +1,4 @@
package palm
package controller
import (
"encoding/json"
@ -7,15 +7,47 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
@ -27,7 +59,7 @@ func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
Content: string(message.Content),
}
if message.Role == "user" {
palmMessage.Author = "0"
@ -39,15 +71,15 @@ func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
content, _ := json.Marshal(candidate.Content)
choice := dto.OpenAITextResponseChoice{
choice := OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Message: Message{
Role: "assistant",
Content: content,
},
@ -58,20 +90,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
@ -112,7 +144,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
dataChan <- string(jsonResponse)
stopChan <- true
}()
service.SetEventStreamHeaders(c)
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@ -125,28 +157,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
@ -156,8 +188,8 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
@ -165,7 +197,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

288
controller/relay-tencent.go Normal file
View File

@ -0,0 +1,288 @@
package controller
import (
"bufio"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"sort"
"strconv"
"strings"
)
// https://cloud.tencent.com/document/product/1729/97732
type TencentMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
SecretId string `json:"secret_id"` // 官网 SecretId
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
// 例如1529223702如果与当前时间相差过大会引起签名过期错误
Timestamp int64 `json:"timestamp"`
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
// 单位为秒Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` //请求 Id用于问题排查
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
// 建议该参数和 top_p 只设置1个不要同时更改 top_p
Temperature float64 `json:"temperature"`
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
// 建议该参数和 temperature 只设置1个不要同时更改
TopP float64 `json:"top_p"`
// Stream 0同步1流式 默认协议SSE)
// 同步请求超时60s如果内容较长建议使用流式
Stream int `json:"stream"`
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
// 输入 content 总数最大支持 3000 token。
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
Created string `json:"created,omitempty"` // unix 时间戳的字符串
Id string `json:"id,omitempty"` // 会话 id
Usage Usage `json:"usage,omitempty"` // token 数量
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"note,omitempty"` // 注释
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: string(message.Content),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: string(message.Content),
Role: message.Role,
})
}
stream := 0
if request.Stream {
stream = 1
}
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
}
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: response.Usage,
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &stopFinishReason
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if TencentResponse.Error.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}

752
controller/relay-text.go Normal file
View File

@ -0,0 +1,752 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
var err error
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
}
fullRequestURL += "?access_token=" + apiKey
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1beta"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
//log.Println(fullRequestURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
if err != nil {
return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
modelPrice := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < 0 || userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !tokenUnlimited {
// 非无限令牌,判断令牌额度是否充足
tokenQuota := c.GetInt("token_quota")
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
var jsonData []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduEmbeddingRequest)
default:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonData, err = json.Marshal(baiduRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
var jsonStr []byte
var err error
switch relayMode {
case RelayModeEmbeddings:
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliEmbeddingRequest)
default:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err = json.Marshal(aliRequest)
}
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeTencent:
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
if err != nil {
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
}
tencentRequest := requestOpenAI2Tencent(textRequest)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
jsonStr, err := json.Marshal(tencentRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
sign := getTencentSign(*tencentRequest, secretKey)
c.Request.Header.Set("Authorization", sign)
requestBody = bytes.NewBuffer(jsonStr)
case APITypeAIProxyLibrary:
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
var req *http.Request
var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// 设置GetBody函数该函数返回一个新的io.ReadCloser该io.ReadCloser返回与原始请求体相同的数据
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if c.Request.Header.Get("OpenAI-Organization") != "" {
req.Header.Set("OpenAI-Organization", c.Request.Header.Get("OpenAI-Organization"))
}
if channelType == common.ChannelTypeOpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypeGemini:
req.Header.Set("Content-Type", "application/json")
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
if apiType != APITypeGemini {
// 设置公共头部...
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
resp, err = httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
}
var textResponse TextResponse
tokenName := c.GetString("token_name")
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(有疑问请联系管理员)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", userId, channelId, tokenId, textRequest.Model, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), isStream)
//if quota != 0 {
//
//}
}()
}(c.Request.Context())
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, resp)
default:
err, usage = baiduHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
// zhipu's API does not return prompt tokens & completion tokens
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
return nil
}
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
var err *OpenAIErrorWithStatusCode
var usage *Usage
switch relayMode {
case RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeXunfei:
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aiProxyLibraryHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeTencent:
if isStream {
err, responseText := tencentStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := tencentHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

330
controller/relay-utils.go Normal file
View File

@ -0,0 +1,330 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"log"
"math"
"net/http"
"one-api/common"
"strconv"
"strings"
"unicode/utf8"
)
var stopFinishReason = "stop"
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
}
common.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return defaultTokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func getImageToken(imageUrl *MessageImageUrl) (int, error) {
if imageUrl.Detail == "low" {
return 85, nil
}
var config image.Config
var err error
var format string
if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
return 0, err
}
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
}
// TODO: 适配官方auto计费
if config.Width < 512 && config.Height < 512 {
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
// 如果图片尺寸小于512强制使用low
imageUrl.Detail = "low"
return 85, nil
}
}
shortSide := config.Width
otherSide := config.Height
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
shortSide = config.Height
otherSide = config.Width
}
// 将最小变的尺寸缩小到768以下如果大于768则缩放到768
if shortSide > 768 {
scale = float64(shortSide) / 768
shortSide = 768
}
// 将另一边按照相同的比例缩小,向上取整
otherSide = int(math.Ceil(float64(otherSide) / scale))
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
// 计算图片的token数量(边的长度除以512向上取整)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles)
return tiles*170 + 85, nil
}
func countTokenMessages(messages []Message, model string) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
var arrayContent []MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err
} else {
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
} else {
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil
}
func countTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return countTokenText(v, model)
case []string:
text := ""
for _, s := range v {
text += s
}
return countTokenText(text, model)
}
return 0
}
func countAudioToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
} else {
return countTokenText(text, model)
}
}
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
text := err.Error()
// 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
openAIError := OpenAIError{
Message: text,
Type: "new_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" {
return true
}
return false
}
func shouldEnableChannel(err error, openAIErr *OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var textResponse TextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}

View File

@ -1,4 +1,4 @@
package xunfei
package controller
import (
"crypto/hmac"
@ -12,9 +12,6 @@ import (
"net/http"
"net/url"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
@ -22,14 +19,69 @@ import (
// https://console.xfyun.cn/services/cbm
// https://www.xfyun.cn/doc/spark/Web.html
func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
type XunfeiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type XunfeiChatRequest struct {
Header struct {
AppId string `json:"app_id"`
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
Message struct {
Text []XunfeiMessage `json:"text"`
} `json:"message"`
} `json:"payload"`
}
type XunfeiChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
}
type XunfeiChatResponse struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
messages := make([]XunfeiMessage, 0, len(request.Messages))
shouldCovertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
for _, message := range request.Messages {
if message.Role == "system" && shouldCovertSystemMessage {
if message.Role == "system" {
messages = append(messages, XunfeiMessage{
Role: "user",
Content: message.StringContent(),
Content: string(message.Content),
})
messages = append(messages, XunfeiMessage{
Role: "assistant",
@ -38,7 +90,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
} else {
messages = append(messages, XunfeiMessage{
Role: message.Role,
Content: message.StringContent(),
Content: string(message.Content),
})
}
}
@ -52,7 +104,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
return &xunfeiRequest
}
func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
@ -61,24 +113,24 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse
}
}
content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
choice := dto.OpenAITextResponseChoice{
choice := OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: relaycommon.StopFinishReason,
FinishReason: stopFinishReason,
}
fullTextResponse := dto.OpenAITextResponse{
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []dto.OpenAITextResponseChoice{choice},
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Payload.Usage.Text,
}
return &fullTextResponse
}
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
@ -86,16 +138,16 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo
},
}
}
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &relaycommon.StopFinishReason
choice.FinishReason = &stopFinishReason
}
response := dto.ChatCompletionsStreamResponse{
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
@ -126,14 +178,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
service.SetEventStreamHeaders(c)
var usage dto.Usage
setEventStreamHeaders(c)
var usage Usage
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
@ -156,13 +208,13 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
var usage dto.Usage
var usage Usage
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
@ -179,26 +231,20 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
case stop = <-stopChan:
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
@ -242,46 +288,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return dataChan, stopChan, nil
}
func apiVersion2domain(apiVersion string) string {
switch apiVersion {
case "v1.1":
return "general"
case "v2.1":
return "generalv2"
case "v3.1":
return "generalv3"
case "v3.5":
return "generalv3.5"
case "v4.0":
return "4.0Ultra"
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
domain := "general"
if apiVersion != "v1.1" {
domain += strings.Split(apiVersion, ".")[0]
}
return "general" + apiVersion
}
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
apiVersion := getAPIVersion(c, modelName)
domain := apiVersion2domain(apiVersion)
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}
func getAPIVersion(c *gin.Context, modelName string) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion != "" {
return apiVersion
}
parts := strings.Split(modelName, "-")
if len(parts) == 2 {
apiVersion = parts[1]
return apiVersion
}
apiVersion = c.GetString("api_version")
if apiVersion != "" {
return apiVersion
}
apiVersion = "v1.1"
common.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}

View File

@ -1,4 +1,4 @@
package zhipu
package controller
import (
"bufio"
@ -8,9 +8,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"sync"
"time"
@ -21,6 +18,46 @@ import (
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
@ -71,13 +108,13 @@ func getZhipuToken(apikey string) string {
return tokenString
}
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, ZhipuMessage{
Role: "system",
Content: message.StringContent(),
Content: string(message.Content),
})
messages = append(messages, ZhipuMessage{
Role: "user",
@ -86,7 +123,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
} else {
messages = append(messages, ZhipuMessage{
Role: message.Role,
Content: message.StringContent(),
Content: string(message.Content),
})
}
}
@ -98,19 +135,19 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
}
}
func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: response.Data.TaskId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
openaiChoice := dto.OpenAITextResponseChoice{
openaiChoice := OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Message: Message{
Role: choice.Role,
Content: content,
},
@ -124,36 +161,47 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
return &fullTextResponse
}
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(zhipuResponse)
response := dto.ChatCompletionsStreamResponse{
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString("")
choice.FinishReason = &relaycommon.StopFinishReason
response := dto.ChatCompletionsStreamResponse{
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = ""
choice.FinishReason = &stopFinishReason
response := ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "chatglm",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage *Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
@ -177,7 +225,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
@ -212,28 +260,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if !zhipuResponse.Success {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: zhipuResponse.Msg,
Type: "zhipu_error",
Param: "",
@ -245,7 +293,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)

View File

@ -1,261 +1,420 @@
package controller
import (
"bytes"
"errors"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
default:
err = relay.TextHelper(c)
}
return err
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
return contentList
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
return nil
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
return input
}
type AudioRequest struct {
Model string `json:"model"`
Voice string `json:"voice"`
Input string `json:"input"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens uint `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens uint `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
}
type AudioResponse struct {
Text string `json:"text,omitempty"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
if openaiErr.LocalError {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
if openaiErr.StatusCode == 307 {
return true
}
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
} else {
if err.StatusCode == http.StatusTooManyRequests {
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
channelId := c.GetInt("channel_id")
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message)
}
return false
}
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message)
}
}
func RelayMidjourney(c *gin.Context) {
relayMode := c.GetInt("relay_mode")
var err *dto.MidjourneyResponse
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/describe") {
relayMode = RelayModeMidjourneyDescribe
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}
var err *MidjourneyResponse
switch relayMode {
case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c)
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
err = relayMidjourneyTask(c, relayMode)
default:
err = relay.RelayMidjourneySubmit(c, relayMode)
err = relayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
statusCode := http.StatusBadRequest
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(400, gin.H{
"error": err.Description + " " + err.Result,
})
}
c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
"type": "upstream_error",
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
//if shouldDisableChannel(&err.OpenAIError) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//};''''''''''''''''''''''''''''''''
}
}
func RelayNotImplemented(c *gin.Context) {
err := dto.OpenAIError{
err := OpenAIError{
Message: "API not implemented",
Type: "new_api_error",
Param: "",
@ -267,7 +426,7 @@ func RelayNotImplemented(c *gin.Context) {
}
func RelayNotFound(c *gin.Context) {
err := dto.OpenAIError{
err := OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
@ -277,94 +436,3 @@ func RelayNotFound(c *gin.Context) {
"error": err,
})
}
func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id")
relayMode := c.GetInt("relay_mode")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
taskErr := taskRelayHandler(c, relayMode)
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
c.JSON(taskErr.StatusCode, taskErr)
}
}
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
}
return err
}
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
if taskErr == nil {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if taskErr.StatusCode == http.StatusTooManyRequests {
return true
}
if taskErr.StatusCode == 307 {
return true
}
if taskErr.StatusCode/100 == 5 {
// 超时不重试
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
return false
}
return true
}
if taskErr.StatusCode == http.StatusBadRequest {
return false
}
if taskErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if taskErr.LocalError {
return false
}
if taskErr.StatusCode/100 == 2 {
return false
}
return true
}

View File

@ -1,99 +0,0 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}

View File

@ -1,284 +0,0 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"sort"
"strconv"
"time"
)
func UpdateTaskBulk() {
//revocer
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog("任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)
for _, t := range allTasks {
platformTask[t.Platform] = append(platformTask[t.Platform], t)
}
for platform, tasks := range platformTask {
if len(tasks) == 0 {
continue
}
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Task)
nullTaskIds := make([]int64, 0)
for _, task := range tasks {
if task.TaskID == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.ID)
continue
}
taskM[task.TaskID] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
}
if len(nullTaskIds) > 0 {
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
UpdateTaskByPlatform(platform, taskChannelM, taskM)
}
common.SysLog("任务进度轮询完成")
}
}
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
switch platform {
case constant.TaskPlatformMidjourney:
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
}
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil {
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
}
}
return nil
}
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
channel, err := model.CacheGetChannel(channelId)
if err != nil {
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
err = model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
}
return err
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
if adaptor == nil {
return errors.New("adaptor not found")
}
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
"ids": taskIds,
})
if err != nil {
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
return err
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
return err
}
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
return err
}
for _, responseItem := range responseItems.Data {
task := taskM[responseItem.TaskID]
if !checkTaskNeedUpdate(task, responseItem) {
continue
}
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
if responseItem.Status == model.TaskStatusSuccess {
task.Progress = "100%"
}
task.Data = responseItem.Data
err = task.Update()
if err != nil {
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
}
}
return nil
}
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if string(oldTask.Status) != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
return true
}
oldData, _ := json.Marshal(oldTask.Data)
newData, _ := json.Marshal(newTask.Data)
sort.Slice(oldData, func(i, j int) bool {
return oldData[i] < oldData[j]
})
sort.Slice(newData, func(i, j int) bool {
return newData[i] < newData[j]
})
if string(oldData) != string(newData) {
return true
}
return false
}
func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
queryParams := model.SyncTaskQueryParams{
Platform: constant.TaskPlatform(c.Query("platform")),
TaskID: c.Query("task_id"),
Status: c.Query("status"),
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
}
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
queryParams := model.SyncTaskQueryParams{
Platform: constant.TaskPlatform(c.Query("platform")),
TaskID: c.Query("task_id"),
Status: c.Query("status"),
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
}
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

View File

@ -1,124 +0,0 @@
package controller
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"io"
"net/http"
"one-api/common"
"one-api/model"
"sort"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
func TelegramBind(c *gin.Context) {
if !common.TelegramOAuthEnabled {
c.JSON(200, gin.H{
"message": "管理员未开启通过 Telegram 登录以及注册",
"success": false,
})
return
}
params := c.Request.URL.Query()
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
c.JSON(200, gin.H{
"message": "无效的请求",
"success": false,
})
return
}
telegramId := params["id"][0]
if model.IsTelegramIdAlreadyTaken(telegramId) {
c.JSON(200, gin.H{
"message": "该 Telegram 账户已被绑定",
"success": false,
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user := model.User{Id: id.(int)}
if err := user.FillUserById(); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId
if err := user.Update(false); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
c.Redirect(302, "/setting")
}
func TelegramLogin(c *gin.Context) {
if !common.TelegramOAuthEnabled {
c.JSON(200, gin.H{
"message": "管理员未开启通过 Telegram 登录以及注册",
"success": false,
})
return
}
params := c.Request.URL.Query()
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
c.JSON(200, gin.H{
"message": "无效的请求",
"success": false,
})
return
}
telegramId := params["id"][0]
user := model.User{TelegramId: telegramId}
if err := user.FillUserByTelegramId(); err != nil {
c.JSON(200, gin.H{
"message": err.Error(),
"success": false,
})
return
}
setupLogin(&user, c)
}
func checkTelegramAuthorization(params map[string][]string, token string) bool {
strs := []string{}
var hash = ""
for k, v := range params {
if k == "hash" {
hash = v[0]
continue
}
strs = append(strs, k+"="+v[0])
}
sort.Strings(strs)
var imploded = ""
for _, s := range strs {
if imploded != "" {
imploded += "\n"
}
imploded += s
}
sha256hash := sha256.New()
io.WriteString(sha256hash, token)
hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
io.WriteString(hmachash, imploded)
ss := hex.EncodeToString(hmachash.Sum(nil))
return hash == ss
}

View File

@ -123,19 +123,10 @@ func AddToken(c *gin.Context) {
})
return
}
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: key,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
@ -143,8 +134,6 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
}
err = cleanToken.Insert()
if err != nil {
@ -232,8 +221,6 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
}
err = cleanToken.Update()
if err != nil {

View File

@ -1,20 +1,19 @@
package controller
import "C"
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/checkout/session"
"github.com/samber/lo"
epay "github.com/star-horizon/go-epay"
"log"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"time"
)
type PayRequest struct {
type EpayRequest struct {
Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
@ -25,114 +24,142 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"`
}
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
return "", fmt.Errorf("无效的Stripe API密钥")
func GetEpayClient() *epay.Client {
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
return nil
}
stripe.Key = common.StripeApiSecret
params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
withUrl, err := epay.NewClientWithUrl(&epay.Config{
PartnerID: common.EpayId,
Key: common.EpayKey,
}, common.PayAddress)
if err != nil {
return "", err
return nil
}
return result.URL, nil
return withUrl
}
func GetPayAmount(count float64) float64 {
return count * common.StripeUnitPrice
}
func GetChargedAmount(count float64, user model.User) float64 {
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if topUpGroupRatio == 0 {
topUpGroupRatio = 1
func GetAmount(count float64, user model.User) float64 {
// 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
return count * topUpGroupRatio
amount := count * common.Price * topupGroupRatio
return amount
}
func RequestPayLink(c *gin.Context) {
var req PayRequest
func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
amount := GetAmount(float64(req.Amount), *user)
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
referenceId := "ref_" + common.Sha1(reference)
var payType epay.PurchaseType
if req.PaymentMethod == "zfb" {
payType = epay.Alipay
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
payMoney := amount
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
Money: payMoney,
TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
Status: "pending",
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"payLink": payLink,
},
})
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
}
func RequestAmount(c *gin.Context) {
@ -142,23 +169,12 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
if req.Amount < 1 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
payMoney := GetPayAmount(float64(req.Amount))
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
},
})
payMoney := GetAmount(float64(req.Amount), *user)
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}

View File

@ -7,12 +7,9 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"one-api/constant"
)
type LoginRequest struct {
@ -68,8 +65,6 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username)
session.Set("role", user.Role)
session.Set("status", user.Status)
session.Set("group", user.Group)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@ -161,9 +156,8 @@ func Register(c *gin.Context) {
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "数据库错误,请稍后重试",
"message": err.Error(),
})
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return
}
if exist {
@ -191,48 +185,6 @@ func Register(c *gin.Context) {
})
return
}
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@ -263,8 +215,7 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
users, err := model.SearchUsers(keyword, group)
users, err := model.SearchUsers(keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -323,18 +274,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
// get rand int 28-32
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
user.AccessToken = common.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@ -511,7 +451,7 @@ func UpdateUser(c *gin.Context) {
updatedUser.Password = "" // rollback to what it should be
}
updatePassword := updatedUser.Password != ""
if err := updatedUser.Edit(updatePassword); err != nil {
if err := updatedUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@ -575,7 +515,7 @@ func UpdateSelf(c *gin.Context) {
return
}
func HardDeleteUser(c *gin.Context) {
func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@ -584,7 +524,7 @@ func HardDeleteUser(c *gin.Context) {
})
return
}
originUser, err := model.GetUserByIdUnscoped(id, false)
originUser, err := model.GetUserById(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -608,23 +548,9 @@ func HardDeleteUser(c *gin.Context) {
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
@ -654,7 +580,6 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) {
var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -702,8 +627,8 @@ func CreateUser(c *gin.Context) {
}
type ManageRequest struct {
Id int `json:"id"`
Action string `json:"action"`
Username string `json:"username"`
Action string `json:"action"`
}
// ManageUser Only admin user can do this
@ -719,7 +644,7 @@ func ManageUser(c *gin.Context) {
return
}
user := model.User{
Id: req.Id,
Username: req.Username,
}
// Fill attributes
model.DB.Unscoped().Where(&user).First(&user)
@ -758,7 +683,7 @@ func ManageUser(c *gin.Context) {
})
return
}
if err := user.Delete(); err != nil {
if err := user.HardDelete(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@ -864,11 +789,7 @@ type topUpRequest struct {
Key string `json:"key"`
}
var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) {
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {

View File

@ -78,13 +78,6 @@ func WeChatAuth(c *gin.Context) {
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@ -2,17 +2,17 @@ version: '3.4'
services:
new-api:
image: pengzhile/new-api:latest
image: calciumion/new-api:latest
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
- "3000:3000"
volumes:
- ./data/new-api:/data
- ./data:/data
- ./logs:/app/logs
environment:
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
@ -22,22 +22,13 @@ services:
depends_on:
- redis
- db
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
redis:
image: redis:7.4
image: redis:latest
container_name: redis
restart: always
db:
image: mysql:8.2
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

View File

@ -1,34 +0,0 @@
package dto
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Speed float64 `json:"speed,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type AudioResponse struct {
Text string `json:"text"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}

View File

@ -1,22 +0,0 @@
package dto
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}

View File

@ -1,55 +0,0 @@
package dto
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"`
LocalError bool
}
type GeneralErrorResponse struct {
Error OpenAIError `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}

View File

@ -1,101 +0,0 @@
package dto
//type SimpleMjRequest struct {
// Prompt string `json:"prompt"`
// CustomId string `json:"customId"`
// Action string `json:"action"`
// Content string `json:"content"`
//}
type SwapFaceRequest struct {
SourceBase64 string `json:"sourceBase64"`
TargetBase64 string `json:"targetBase64"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
MaskBase64 string `json:"maskBase64"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse
}
type MidjourneyDto struct {
MjId string `json:"id"`
Action string `json:"action"`
CustomId string `json:"customId"`
BotType string `json:"botType"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
Buttons any `json:"buttons"`
MaskBase64 string `json:"maskBase64"`
Properties *Properties `json:"properties"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}
type MidjourneyWithoutStatus struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
}
type ActionButton struct {
CustomId any `json:"customId"`
Emoji any `json:"emoji"`
Label any `json:"label"`
Type any `json:"type"`
Style any `json:"style"`
}
type Properties struct {
FinalPrompt string `json:"finalPrompt"`
FinalZhPrompt string `json:"finalZhPrompt"`
}

View File

@ -1,6 +0,0 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

View File

@ -1,26 +0,0 @@
package dto
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}

View File

@ -1,22 +0,0 @@
package dto
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
type RerankResponseDocument struct {
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
}

View File

@ -1,6 +0,0 @@
package dto
type SensitiveResponse struct {
SensitiveWords []string `json:"sensitive_words"`
Content string `json:"content"`
}

View File

@ -1,129 +0,0 @@
package dto
import (
"encoding/json"
)
type TaskData interface {
SunoDataResponse | []SunoDataResponse | string | any
}
type SunoSubmitReq struct {
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
Prompt string `json:"prompt,omitempty"`
Mv string `json:"mv,omitempty"`
Title string `json:"title,omitempty"`
Tags string `json:"tags,omitempty"`
ContinueAt float64 `json:"continue_at,omitempty"`
TaskID string `json:"task_id,omitempty"`
ContinueClipId string `json:"continue_clip_id,omitempty"`
MakeInstrumental bool `json:"make_instrumental"`
}
type FetchReq struct {
IDs []string `json:"ids"`
}
type SunoDataResponse struct {
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
type SunoSong struct {
ID string `json:"id"`
VideoURL string `json:"video_url"`
AudioURL string `json:"audio_url"`
ImageURL string `json:"image_url"`
ImageLargeURL string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
ModelName string `json:"model_name"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
Metadata SunoMetadata `json:"metadata"`
}
type SunoMetadata struct {
Tags string `json:"tags"`
Prompt string `json:"prompt"`
GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
AudioPromptID interface{} `json:"audio_prompt_id"`
Duration interface{} `json:"duration"`
ErrorType interface{} `json:"error_type"`
ErrorMessage interface{} `json:"error_message"`
}
type SunoLyrics struct {
ID string `json:"id"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
}
const TaskSuccessCode = "success"
type TaskResponse[T TaskData] struct {
Code string `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
}
func (t *TaskResponse[T]) IsSuccess() bool {
return t.Code == TaskSuccessCode
}
type TaskDto struct {
TaskID string `json:"task_id"` // 第三方id不一定有/ song id\ Task id
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
Progress string `json:"progress"`
Data json.RawMessage `json:"data"`
}
type SunoGoAPISubmitReq struct {
CustomMode bool `json:"custom_mode"`
Input SunoGoAPISubmitReqInput `json:"input"`
NotifyHook string `json:"notify_hook,omitempty"`
}
type SunoGoAPISubmitReqInput struct {
GptDescriptionPrompt string `json:"gpt_description_prompt"`
Prompt string `json:"prompt"`
Mv string `json:"mv"`
Title string `json:"title"`
Tags string `json:"tags"`
ContinueAt float64 `json:"continue_at"`
TaskID string `json:"task_id"`
ContinueClipId string `json:"continue_clip_id"`
MakeInstrumental bool `json:"make_instrumental"`
}
type GoAPITaskResponse[T any] struct {
Code int `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
ErrorMessage string `json:"error_message,omitempty"`
}
type GoAPITaskResponseData struct {
TaskID string `json:"task_id"`
}
type GoAPIFetchResponseData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
Input string `json:"input"`
Clips map[string]SunoSong `json:"clips"`
}

View File

@ -1,10 +0,0 @@
package dto
type TaskError struct {
Code string `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
StatusCode int `json:"-"`
LocalError bool `json:"-"`
Error error `json:"-"`
}

View File

@ -1,188 +0,0 @@
package dto
import "encoding/json"
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
BestOf int `json:"best_of,omitempty"`
Echo bool `json:"echo,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Suffix string `json:"suffix,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
}
type OpenAIFunction struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}
type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) StringContent() string {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
}
return string(m.Content)
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return true
}
return false
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high",
},
})
}
}
}
return contentList
}
return nil
}

View File

@ -1,135 +0,0 @@
package dto
type TextResponseWithError struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type SimpleResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
}
type TextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding any `json:"embedding"`
}
type OpenAIEmbeddingResponse struct {
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
Logprobs *any `json:"logprobs"`
FinishReason *string `json:"finish_reason"`
Index int `json:"index"`
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
if c.Content == nil {
return ""
}
return *c.Content
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

66
go.mod
View File

@ -1,33 +1,25 @@
module one-api
// +heroku goVersion go1.18
go 1.21
go 1.18
require (
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.6.0
github.com/chai2010/webp v1.1.1
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.19.0
github.com/go-playground/validator/v10 v10.16.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/pkoukk/tiktoken-go v0.1.6
github.com/samber/lo v1.38.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stripe/stripe-go/v79 v79.12.0
golang.org/x/crypto v0.31.0
golang.org/x/image v0.18.0
golang.org/x/net v0.33.0
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.17.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@ -35,25 +27,18 @@ require (
)
require (
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@ -64,24 +49,25 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.1 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

161
go.sum
View File

@ -1,50 +1,26 @@
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
github.com/bytedance/sonic v1.11.2 h1:ywfwo0a/3j9HR8wsYGWsIWl2mvRsI950HyoxiBERw5A=
github.com/bytedance/sonic v1.11.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk=
github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0=
github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.6.0 h1:0Z7D/bVhE6ja07lI8CTjTonp6SB07o8bNuFyRbsBUQg=
github.com/gin-contrib/cors v1.6.0/go.mod h1:cI+h6iOAyxKRtUtC6iF/Si1KSFvGm/gK+kshxlCi8ro=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
@ -61,7 +37,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@ -72,8 +47,10 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.19.0 h1:ol+5Fu+cSq9JD7SoSqe04GMI92cbn0+wvQ3bZ8b/AU4=
github.com/go-playground/validator/v10 v10.19.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
@ -85,12 +62,11 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@ -103,12 +79,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@ -118,9 +94,8 @@ github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@ -131,15 +106,19 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -147,28 +126,25 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -179,11 +155,9 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs=
github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@ -194,60 +168,60 @@ github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVM
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
@ -266,5 +240,4 @@ gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

27
main.go
View File

@ -3,29 +3,26 @@ package main
import (
"embed"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/controller"
"one-api/middleware"
"one-api/model"
"one-api/router"
"one-api/service"
"os"
"strconv"
_ "net/http/pprof"
)
//go:embed web/dist
//go:embed web/build
var buildFS embed.FS
//go:embed web/dist/index.html
//go:embed web/build/index.html
var indexPage []byte
func main() {
@ -42,11 +39,6 @@ func main() {
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
@ -60,8 +52,6 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
@ -98,14 +88,9 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if common.IsMasterNode && constant.UpdateTask {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
})
gopool.Go(func() {
controller.UpdateTaskBulk()
})
}
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
@ -120,7 +105,7 @@ func main() {
common.SysLog("pprof enabled")
}
service.InitTokenEncoders()
controller.InitTokenEncoders()
// Initialize HTTP server
server := gin.New()

View File

@ -6,29 +6,15 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c)
username := session.Get("username")
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false
if username == nil {
// Check access token
accessToken := c.Request.Header.Get("Authorization")
@ -42,21 +28,11 @@ func authHelper(c *gin.Context, minRole int) {
}
user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid
username = user.Username
role = user.Role
id = user.Id
status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -66,36 +42,6 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -104,14 +50,6 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort()
return
}
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{
"success": false,
@ -120,33 +58,12 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort()
return
}
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username)
c.Set("role", role)
c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next()
}
func TryUserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
session := sessions.Default(c)
id := session.Get("id")
if id != nil {
c.Set("id", id)
}
c.Next()
}
}
func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser)
@ -182,32 +99,17 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
}
token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.UserId)
}
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !userEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
c.Set("id", token.UserId)
@ -223,13 +125,17 @@ func TokenAuth() func(c *gin.Context) {
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1])
c.Set("channelId", parts[1])
} else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
}
}

View File

@ -1,15 +1,10 @@
package middleware
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"strconv"
"strings"
@ -22,55 +17,64 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup)
channelId, ok := c.Get("channelId")
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
channel, err = model.GetChannelById(id, true)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
// Midjourney
if modelRequest.Model == "" {
modelRequest.Model = "midjourney"
}
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求: "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
@ -83,153 +87,58 @@ func Distribute() func(c *gin.Context) {
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
}
c.Next()
}
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/suno/") {
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeSunoFetch ||
relayMode == relayconstant.RelayModeSunoFetchByID {
shouldSelectChannel = false
} else {
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
modelRequest.Model = modelName
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
}
c.Set("relay_mode", relayMode)
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}

View File

@ -1,13 +1,11 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
userId := c.GetInt("id")
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@ -15,15 +13,5 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
},
})
c.Abort()
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
c.JSON(statusCode, gin.H{
"description": description,
"type": "new_api_error",
"code": code,
})
c.Abort()
common.LogError(c.Request.Context(), description)
common.LogError(c.Request.Context(), message)
}

View File

@ -3,8 +3,6 @@ package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
)
@ -29,20 +27,8 @@ func GetGroupModels(group string) []string {
return models
}
func GetEnabledModels() []string {
var models []string
// Find distinct models
DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
return models
}
func GetAllEnableAbilities() []Ability {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@ -50,60 +36,9 @@ func getPriority(group string, model string, retry int) (int, error) {
trueVal = "true"
}
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
// 处理错误
return 0, err
}
if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
// 如果重试次数大于优先级数,则使用最小的优先级
priorityToUse = priorities[len(priorities)-1]
} else {
priorityToUse = priorities[retry]
}
return priorityToUse, nil
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
return channelQuery
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil
channelQuery := getChannelQuery(group, model, retry)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
@ -117,16 +52,21 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight + 10
weightSum += ability_.Weight
}
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight) + 10
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
}
} else {
@ -153,16 +93,7 @@ func (channel *Channel) AddAbilities() error {
abilities = append(abilities, ability)
}
}
if len(abilities) == 0 {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Create(&chunk).Error
if err != nil {
return err
}
}
return nil
return DB.Create(&abilities).Error
}
func (channel *Channel) DeleteAbilities() error {
@ -210,18 +141,13 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
}
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
if err != nil {
return 0, err
}

View File

@ -25,6 +25,9 @@ var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error {
if !common.RedisEnabled {
return token.SelectUpdate()
}
jsonBytes, err := json.Marshal(token)
if err != nil {
return err
@ -87,7 +90,7 @@ func SyncTokenCache(frequency int) {
}
} else {
// 如果数据库中存在先检查redis
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
_, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果redis中不存在则跳过
continue
@ -165,11 +168,7 @@ func CacheUpdateUserQuota(id int) error {
if err != nil {
return err
}
return cacheSetUserQuota(id, quota)
}
func cacheSetUserQuota(id int, quota int) error {
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
return err
}
@ -205,30 +204,6 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
@ -290,16 +265,14 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
} else if strings.HasPrefix(model, "g-") {
model = "g-*"
}
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry)
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@ -307,44 +280,35 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
uniquePriorities := make(map[int]bool)
for _, channel := range channels {
uniquePriorities[int(channel.GetPriority())] = true
}
var sortedUniquePriorities []int
for priority := range uniquePriorities {
sortedUniquePriorities = append(sortedUniquePriorities, priority)
}
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
if retry >= len(uniquePriorities) {
retry = len(uniquePriorities) - 1
}
targetPriority := int64(sortedUniquePriorities[retry])
// get the priority for the given retry number
var targetChannels []*Channel
for _, channel := range channels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
// 平滑系数
smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range targetChannels {
totalWeight += channel.GetWeight() + smoothingFactor
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
}
if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range targetChannels {
randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 {
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
return channel, nil
}
}

View File

@ -1,18 +1,15 @@
package model
import (
"encoding/json"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"`
Key string `json:"key" gorm:"not null;index"`
OpenAIOrganization *string `json:"openai_organization"`
TestModel *string `json:"test_model"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"`
@ -27,49 +24,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
}
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error())
}
}
return otherInfo
}
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
common.SysError("failed to marshal other info: " + err.Error())
return
}
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
}
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
@ -87,46 +43,21 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
return channels, err
}
func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
var channels []*Channel
func SearchChannels(keyword string, group string) (channels []*Channel, err error) {
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
// 构造WHERE子句
var whereClause string
var args []interface{}
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
if group != "" {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
err = DB.Omit("key").Where("(id = ? or name LIKE ? or "+keyCol+" = ?) and "+groupCol+" LIKE ?", common.String2Int(keyword), keyword+"%", keyword, "%"+group+"%").Find(&channels).Error
} else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
}
// 执行查询
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error
if err != nil {
return nil, err
}
return channels, nil
return channels, err
}
func GetChannelById(id int, selectAll bool) (*Channel, error) {
@ -203,13 +134,6 @@ func (channel *Channel) GetModelMapping() string {
return *channel.ModelMapping
}
func (channel *Channel) GetStatusCodeMapping() string {
if channel.StatusCodeMapping == nil {
return ""
}
return *channel.StatusCodeMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
@ -261,31 +185,15 @@ func (channel *Channel) Delete() error {
return err
}
func UpdateChannelStatusById(id int, status int, reason string) {
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
}
channel, err := GetChannelById(id, true)
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
// find channel by id error, directly update status
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
} else {
// find channel by id success, update status and other info
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
common.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {

View File

@ -3,12 +3,9 @@ package model
import (
"context"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
type Log struct {
@ -27,7 +24,6 @@ type Log struct {
IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Other string `json:"other"`
}
const (
@ -39,7 +35,7 @@ const (
)
func GetLogByKey(key string) (logs []*Log, err error) {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
return logs, err
}
@ -55,19 +51,18 @@ func RecordLog(userId int, logType int, content string) {
Type: logType,
Content: content,
}
err := LOG_DB.Create(log).Error
err := DB.Create(log).Error
if err != nil {
common.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
Username: username,
@ -83,28 +78,27 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
err := DB.Create(log).Error
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
common.SafeGoroutine(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
})
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB
tx = DB
} else {
tx = LOG_DB.Where("type = ?", logType)
tx = DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", modelName)
tx = tx.Where("model_name = ?", modelName)
}
if username != "" {
tx = tx.Where("username = ?", username)
@ -121,26 +115,19 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
return logs, total, err
return logs, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB.Where("user_id = ?", userId)
tx = DB.Where("user_id = ?", userId)
} else {
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
tx = DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", modelName)
tx = tx.Where("model_name = ?", modelName)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
@ -151,30 +138,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
for i := range logs {
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")
}
logs[i].Other = common.MapToJsonStr(otherMap)
}
return logs, total, err
return logs, err
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
@ -185,18 +159,12 @@ type Stat struct {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" {
tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
@ -205,29 +173,17 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp)
}
if modelName != "" {
tx = tx.Where("model_name like ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// 只统计最近60秒的rpm和tpm
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
return stat
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@ -247,25 +203,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
return token
}
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
var total int64 = 0
for {
if nil != ctx.Err() {
return total, ctx.Err()
}
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
if nil != result.Error {
return total, result.Error
}
total += result.RowsAffected
if result.RowsAffected < int64(limit) {
break
}
}
return total, nil
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}

View File

@ -5,18 +5,14 @@ import (
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"log"
"one-api/common"
"os"
"strings"
"sync"
"time"
)
var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
@ -32,7 +28,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: nil,
AccessToken: common.GetUUID(),
Quota: 100000000,
}
DB.Create(&rootUser)
@ -40,9 +36,9 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB(envName string) (*gorm.DB, error) {
dsn := os.Getenv(envName)
if dsn != "" {
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
@ -54,13 +50,6 @@ func chooseDB(envName string) (*gorm.DB, error) {
PrepareStmt: true, // precompile SQL
})
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL
common.SysLog("using MySQL as database")
// check parseTime
@ -71,7 +60,6 @@ func chooseDB(envName string) (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@ -85,7 +73,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
func InitDB() (err error) {
db, err := chooseDB("SQL_DSN")
db, err := chooseDB()
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@ -95,58 +83,56 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateDB()
return err
} else {
common.FatalLog(err)
}
return err
}
func InitLogDB() (err error) {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
LOG_DB = db
sqlDB, err := LOG_DB.DB()
err = db.AutoMigrate(&Channel{})
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
err = db.AutoMigrate(&Token{})
if err != nil {
return err
}
//if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateLOGDB()
err = db.AutoMigrate(&User{})
if err != nil {
return err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
} else {
common.FatalLog(err)
@ -154,109 +140,11 @@ func InitLogDB() (err error) {
return err
}
func migrateDB() error {
err := DB.AutoMigrate(&Channel{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
func CloseDB() error {
sqlDB, err := DB.DB()
if err != nil {
return err
}
err = sqlDB.Close()
return err
}
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}
var (
lastPingTime time.Time
pingMutex sync.Mutex
)
func PingDB() error {
pingMutex.Lock()
defer pingMutex.Unlock()
if time.Since(lastPingTime) < time.Second*10 {
return nil
}
sqlDB, err := DB.DB()
if err != nil {
log.Printf("Error getting sql.DB from GORM: %v", err)
return err
}
err = sqlDB.Ping()
if err != nil {
log.Printf("Error pinging DB: %v", err)
return err
}
lastPingTime = time.Now()
common.SysLog("Database pinged successfully")
return nil
}

View File

@ -4,23 +4,21 @@ type Midjourney struct {
Id int `json:"id"`
Code int `json:"code"`
UserId int `json:"user_id" gorm:"index"`
Action string `json:"action" gorm:"type:varchar(40);index"`
Action string `json:"action"`
MjId string `json:"mj_id" gorm:"index"`
Prompt string `json:"prompt"`
PromptEn string `json:"prompt_en"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
ImageUrl string `json:"image_url"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
Buttons string `json:"buttons"`
Properties string `json:"properties"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段

View File

@ -2,7 +2,6 @@ package model
import (
"one-api/common"
"one-api/constant"
"strconv"
"strings"
"time"
@ -31,30 +30,24 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
common.OptionMap["Notice"] = ""
common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = ""
@ -62,22 +55,13 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["OutProxyUrl"] = ""
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
common.OptionMap["StripePriceId"] = common.StripePriceId
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["PayAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
@ -91,8 +75,6 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2
@ -100,18 +82,6 @@ func InitOptionMap() {
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@ -168,7 +138,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@ -179,22 +149,14 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
common.TelegramOAuthEnabled = boolValue
case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
@ -207,32 +169,8 @@ func updateOptionMap(key string, value string) (err error) {
common.DisplayTokenStatEnabled = boolValue
case "DrawingEnabled":
common.DrawingEnabled = boolValue
case "TaskEnabled":
common.TaskEnabled = boolValue
case "DataExportEnabled":
common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
}
}
switch key {
@ -251,34 +189,20 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPToken = value
case "ServerAddress":
common.ServerAddress = value
case "OutProxyUrl":
common.OutProxyUrl = value
case "Chats":
err = constant.UpdateChatsByJsonString(value)
case "StripeApiSecret":
common.StripeApiSecret = value
case "StripeWebhookSecret":
common.StripeWebhookSecret = value
case "StripePriceId":
common.StripePriceId = value
case "PaymentEnabled":
common.PaymentEnabled, _ = strconv.ParseBool(value)
case "StripeUnitPrice":
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
common.MinTopUp, _ = strconv.Atoi(value)
case "PayAddress":
common.PayAddress = value
case "EpayId":
common.EpayId = value
case "EpayKey":
common.EpayKey = value
case "Price":
common.Price, _ = strconv.ParseFloat(value, 64)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer":
common.Footer = value
case "SystemName":
@ -291,10 +215,6 @@ func updateOptionMap(key string, value string) (err error) {
common.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value
case "TelegramBotToken":
common.TelegramBotToken = value
case "TelegramBotName":
common.TelegramBotName = value
case "TurnstileSiteKey":
common.TurnstileSiteKey = value
case "TurnstileSecretKey":
@ -319,10 +239,6 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink":
@ -335,10 +251,6 @@ func updateOptionMap(key string, value string) (err error) {
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
constant.SensitiveWordsFromString(value)
case "StreamCacheQueueLength":
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}

View File

@ -1,79 +0,0 @@
package model
import (
"one-api/common"
"sync"
"time"
)
type Pricing struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"`
}
var (
pricingMap []Pricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
pricing.ModelRatio = common.GetModelRatio(model)
pricing.CompletionRatio = common.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
}
lastGetPricingTime = time.Now()
}

View File

@ -8,17 +8,16 @@ import (
)
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
}
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
@ -56,7 +55,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if common.UsingPostgreSQL {
keyCol = `"key"`
}
common.RandomSleep()
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
@ -78,7 +77,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil
}

View File

@ -1,304 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"one-api/constant"
commonRelay "one-api/relay/common"
"time"
)
type TaskStatus string
const (
TaskStatusNotStart TaskStatus = "NOT_START"
TaskStatusSubmitted = "SUBMITTED"
TaskStatusQueued = "QUEUED"
TaskStatusInProgress = "IN_PROGRESS"
TaskStatusFailure = "FAILURE"
TaskStatusSuccess = "SUCCESS"
TaskStatusUnknown = "UNKNOWN"
)
type Task struct {
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
CreatedAt int64 `json:"created_at" gorm:"index"`
UpdatedAt int64 `json:"updated_at"`
TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id不一定有/ song id\ Task id
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
UserId int `json:"user_id" gorm:"index"`
ChannelId int `json:"channel_id" gorm:"index"`
Quota int `json:"quota"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Progress string `json:"progress" gorm:"type:varchar(20);index"`
Properties Properties `json:"properties" gorm:"type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
func (t *Task) SetData(data any) {
b, _ := json.Marshal(data)
t.Data = json.RawMessage(b)
}
func (t *Task) GetData(v any) error {
err := json.Unmarshal(t.Data, &v)
return err
}
type Properties struct {
Input string `json:"input"`
}
func (m *Properties) Scan(val interface{}) error {
bytesValue, _ := val.([]byte)
return json.Unmarshal(bytesValue, m)
}
func (m Properties) Value() (driver.Value, error) {
return json.Marshal(m)
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type SyncTaskQueryParams struct {
Platform constant.TaskPlatform
ChannelID string
TaskID string
UserID string
Action string
Status string
StartTimestamp int64
EndTimestamp int64
UserIDs []int
}
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
t := &Task{
UserId: relayInfo.UserId,
SubmitTime: time.Now().Unix(),
Status: TaskStatusNotStart,
Progress: "0%",
ChannelId: relayInfo.ChannelId,
Platform: platform,
}
return t
}
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB.Where("user_id = ?", userId)
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.StartTimestamp != 0 {
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetAllUnFinishSyncTasks(limit int) []*Task {
var tasks []*Task
var err error
// get all tasks progress is not 100%
err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("task_id = ?", taskId).First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
if len(taskIds) == 0 {
return nil, nil
}
var task []*Task
var err error
err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
Find(&task).Error
if err != nil {
return nil, err
}
return task, nil
}
func TaskUpdateProgress(id int64, progress string) error {
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
}
func (Task *Task) Insert() error {
var err error
err = DB.Create(Task).Error
return err
}
func (Task *Task) Update() error {
var err error
err = DB.Save(Task).Error
return err
}
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
if len(TaskIds) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("task_id in (?)", TaskIds).
Updates(params).Error
}
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
if len(taskIDs) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
if len(ids) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", ids).
Updates(params).Error
}
type TaskQuotaUsage struct {
Mode string `json:"mode"`
Count float64 `json:"count"`
}
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
query := DB.Model(Task{})
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}

View File

@ -5,50 +5,24 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
"strconv"
"strings"
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@ -73,14 +47,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
} else if token.Status == common.TokenStatusExpired {
return token, errors.New("该令牌已过期")
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return token, errors.New("该令牌状态不可用")
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
@ -90,7 +62,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return token, errors.New("该令牌已过期")
return nil, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
@ -101,9 +73,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
}
return token, nil
}
@ -127,11 +97,6 @@ func GetTokenById(id int) (*Token, error) {
token := Token{Id: id}
var err error = nil
err = DB.First(&token, "id = ?", id).Error
if err != nil {
if common.RedisEnabled {
go cacheSetToken(&token)
}
}
return &token, err
}
@ -154,8 +119,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
return err
}
@ -257,52 +221,51 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
if quota < 0 {
return 0, errors.New("quota 不能为负数!")
}
if !relayInfo.IsPlayground {
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
token, err := GetTokenById(tokenId)
if err != nil {
return 0, err
}
userQuota, err = GetUserQuota(relayInfo.UserId)
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
if err != nil {
return 0, err
}
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
}
}
err = DecreaseUserQuota(relayInfo.UserId, quota)
err = DecreaseUserQuota(token.UserId, quota)
return userQuota - quota, err
}
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(relayInfo.UserId, -quota)
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !relayInfo.IsPlayground {
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
@ -315,7 +278,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota {
go func() {
email, err := GetUserEmail(relayInfo.UserId)
email, err := GetUserEmail(token.UserId)
if err != nil {
common.SysError("failed to fetch user email: " + err.Error())
}

Some files were not shown because too many files have changed in this diff Show More