Compare commits

..

No commits in common. "main" and "v0.2.1.0-alpha.5" have entirely different histories.

271 changed files with 7015 additions and 30132 deletions

View File

@ -1,5 +1,5 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: 交流社区 - name: 项目群聊
url: https://linux.do url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
about: 项目交流社区 about: QQ 群629454374

View File

@ -1,6 +1,10 @@
name: Publish Docker image (amd64) name: Publish Docker image (amd64)
on: on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
@ -39,7 +43,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
pengzhile/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
pengzhile/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

1
.gitignore vendored
View File

@ -6,4 +6,3 @@ upload
build build
*.db-journal *.db-journal
logs logs
web/dist

View File

@ -1,13 +1,13 @@
FROM oven/bun:latest AS builder FROM node:16 as builder
WORKDIR /build WORKDIR /build
COPY web/package.json . COPY web/package.json .
RUN bun install RUN npm install
COPY ./web . COPY ./web .
COPY ./VERSION . COPY ./VERSION .
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang:1.21 AS builder2 FROM golang AS builder2
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=1 \ CGO_ENABLED=1 \
@ -17,7 +17,7 @@ WORKDIR /build
ADD go.mod go.sum ./ ADD go.mod go.sum ./
RUN go mod download RUN go mod download
COPY . . COPY . .
COPY --from=builder /build/dist ./web/dist COPY --from=builder /build/build ./web/build
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine FROM alpine
@ -25,7 +25,7 @@ FROM alpine
RUN apk update \ RUN apk update \
&& apk upgrade \ && apk upgrade \
&& apk add --no-cache ca-certificates tzdata \ && apk add --no-cache ca-certificates tzdata \
&& update-ca-certificates && update-ca-certificates 2>/dev/null || true
COPY --from=builder2 /build/one-api / COPY --from=builder2 /build/one-api /
EXPOSE 3000 EXPOSE 3000

214
LICENSE
View File

@ -1,201 +1,21 @@
Apache License MIT License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION Copyright (c) 2024 Calcium-Ion
1. Definitions. Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"License" shall mean the terms and conditions for use, reproduction, The above copyright notice and this permission notice shall be included in all
and distribution as defined by Sections 1 through 9 of this document. copies or substantial portions of the Software.
"Licensor" shall mean the copyright owner or entity authorized by THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
the copyright owner that is granting the License. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
"Legal Entity" shall mean the union of the acting entity and all AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
other entities that control, are controlled by, or are under common LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
control with that entity. For the purposes of this definition, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
"control" means (i) the power, direct or indirect, to cause the SOFTWARE.
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -2,21 +2,6 @@
**简介**:Midjourney Proxy API文档 **简介**:Midjourney Proxy API文档
## 接口列表
支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
## 模型列表 ## 模型列表
### midjourney-proxy支持 ### midjourney-proxy支持
@ -72,11 +57,11 @@
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus** 2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**如果是plus版本选择**Midjourney Proxy Plus**
,模型请参考上方模型列表 ,模型请参考上方模型列表
3. **代理**填写midjourney-proxy部署的地址例如http://localhost:8080 3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填 4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
### 对接上游new api ### 对接上游new api
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表 1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
2. **代理**填写上游new api的地址例如http://localhost:3000 2. 地址填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥 3. 密钥填写上游new api的密钥

105
README.md
View File

@ -1,39 +1,40 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API # New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE] > [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
> [!IMPORTANT]
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
> [!WARNING]
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!TIP] > [!NOTE]
> 最新版Docker镜像`calciumion/new-api:latest` > 最新版Docker镜像 calciumion/new-api:latest
> 默认账号root 密码123456 > 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> 更新指令:
> ```
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
## 主要变更 ## 主要变更
此分叉版本的主要变更如下: 此分叉版本的主要变更如下:
1. 全新的UI界面部分界面还待更新 1. 全新的UI界面部分界面还待更新
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md) 2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
+ [x] /mj/submit/action 仅midjourney-proxy-plus支持下同
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap InsightFace
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付 + [x] 易支付
4. 支持用key查询使用额度: 4. 支持用key查询使用额度:
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
5. 渠道显示已使用额度,支持指定组织访问 5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量 6. 分页支持选择每页显示数量
7. 兼容原版One API的数据库可直接使用原版数据库one-api.db 7. 兼容原版One API的数据库可直接使用原版数据库one-api.db
@ -46,39 +47,18 @@
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain 2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login 3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串 4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
14. 支持Rerank模型目前仅兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*, g-* 1. 第三方模型 **gps** gpt-4-gizmo-*
2. 智谱glm-4vglm-4v识图 2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md) 5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
6. [零一万物](https://platform.lingyiwanwu.com/)
7. 自定义渠道,支持填入完整调用地址
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署 ## 部署
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
```shell ```shell
# 使用 SQLite 的部署命令: # 使用 SQLite 的部署命令:
@ -97,40 +77,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问 # 注意数据库要开启远程访问并且只允许服务器IP访问
``` ```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
在`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
## Midjourney接口设置文档 ## Midjourney接口设置文档
[对接文档](Midjourney.md) [对接文档](Midjourney.md)
## Suno接口设置文档 ## 交流群
[对接文档](Suno.md) <img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图 ## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式 夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目 ## 相关项目

View File

@ -1,62 +0,0 @@
# Rerank API文档
**简介**:Rerank API文档
## 接入Dify
模型供应商选择Jina按要求填写模型信息即可接入Dify。
## 请求方式
Post: /v1/rerank
Request:
```json
{
"model": "rerank-multilingual-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
```
Response:
```json
{
"results": [
{
"document": {
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
},
"index": 2,
"relevance_score": 0.9999702
},
{
"document": {
"text": "Carson City is the capital city of the American state of Nevada."
},
"index": 0,
"relevance_score": 0.67800725
},
{
"document": {
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
},
"index": 3,
"relevance_score": 0.02800752
}
],
"usage": {
"prompt_tokens": 158,
"completion_tokens": 0,
"total_tokens": 158
}
}
```

44
Suno.md
View File

@ -1,44 +0,0 @@
# Suno API文档
**简介**:Suno API文档
## 接口列表
支持的接口如下:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id
## 模型列表
### Suno API支持
- suno_music (自定义模式、灵感模式、续写)
- suno_lyrics (生成歌词)
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
```json
{
"suno_music": 0.3,
"suno_lyrics": 0.01
}
```
## 渠道设置
### 对接 Suno API
1.
部署 Suno API并配置好suno账号等强烈建议设置密钥[项目地址](https://github.com/Suno-API/Suno-API)
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
,模型请参考上方模型列表
3. **代理**填写 Suno API 部署的地址例如http://localhost:8080
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
2. **代理**填写上游new api的地址例如http://localhost:3000
3. 密钥填写上游new api的密钥

View File

@ -11,18 +11,17 @@ import (
// Pay Settings // Pay Settings
var StripeApiSecret = "" var PayAddress = ""
var StripeWebhookSecret = "" var CustomCallbackAddress = ""
var StripePriceId = "" var EpayId = ""
var PaymentEnabled = false var EpayKey = ""
var StripeUnitPrice = 8.0 var Price = 7.3
var MinTopUp = 5 var MinTopUp = 1
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API" var SystemName = "New API"
var ServerAddress = "http://localhost:3000" var ServerAddress = "http://localhost:3000"
var OutProxyUrl = ""
var Footer = "" var Footer = ""
var Logo = "" var Logo = ""
var TopUpLink = "" var TopUpLink = ""
@ -32,7 +31,6 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true var DisplayTokenStatEnabled = true
var DrawingEnabled = true var DrawingEnabled = true
var TaskEnabled = true
var DataExportEnabled = true var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute var DataExportInterval = 5 // unit: minute
var DataExportDefaultTime = "hour" // unit: minute var DataExportDefaultTime = "hour" // unit: minute
@ -52,15 +50,12 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制 var EmailDomainRestrictionEnabled = false
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
var EmailDomainWhitelist = []string{ var EmailDomainWhitelist = []string{
"gmail.com", "gmail.com",
"163.com", "163.com",
@ -80,7 +75,6 @@ var LogConsumeEnabled = true
var SMTPServer = "" var SMTPServer = ""
var SMTPPort = 587 var SMTPPort = 587
var SMTPSSLEnabled = false
var SMTPAccount = "" var SMTPAccount = ""
var SMTPFrom = "" var SMTPFrom = ""
var SMTPToken = "" var SMTPToken = ""
@ -88,10 +82,6 @@ var SMTPToken = ""
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@ -120,17 +110,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
@ -143,10 +130,6 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var ( var (
FileUploadPermission = RoleGuestUser FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser FileDownloadPermission = RoleGuestUser
@ -157,10 +140,10 @@ var (
// All duration's unit is seconds // All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration // Shouldn't larger then RateLimitKeyExpirationDuration
var ( var (
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60 GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60 GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10 UploadRateLimitNum = 10
@ -200,12 +183,6 @@ const (
ChannelStatusAutoDisabled = 3 ChannelStatusAutoDisabled = 3
) )
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1 ChannelTypeOpenAI = 1
@ -235,62 +212,35 @@ const (
ChannelTypeMoonshot = 25 ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26 ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27 ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
"", // 0 "", // 0
"https://api.openai.com", // 1 "https://api.openai.com", // 1
"https://oa.api2d.net", // 2 "https://oa.api2d.net", // 2
"", // 3 "", // 3
"http://localhost:11434", // 4 "http://localhost:11434", // 4
"https://api.openai-sb.com", // 5 "https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6 "https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7 "https://api.ohmygpt.com", // 7
"", // 8 "", // 8
"https://api.caipacity.com", // 9 "https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10 "https://api.aiproxy.io", // 10
"", // 11 "", // 11
"https://api.api2gpt.com", // 12 "https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13 "https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14 "https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15 "https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16 "https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17 "https://dashscope.aliyuncs.com", // 17
"", // 18 "", // 18
"https://ai.360.cn", // 19 "https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20 "https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21 "https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22 "https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23 "https://hunyuan.cloud.tencent.com", //23
"https://generativelanguage.googleapis.com", //24 "https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25 "https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26 "https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27 "https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
} }

View File

@ -1,40 +0,0 @@
package common
import (
"errors"
"net/smtp"
"strings"
)
type outlookAuth struct {
username, password string
}
func LoginAuth(username, password string) smtp.Auth {
return &outlookAuth{username, password}
}
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown fromServer")
}
}
return nil, nil
}
func isOutlookServer(server string) bool {
// 兼容多地区的outlook邮箱和ofb邮箱
// 其实应该加一个Option来区分是否用LOGIN的方式登录
// 先临时兼容一下
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}

View File

@ -9,31 +9,22 @@ import (
"time" "time"
) )
func generateMessageID() string {
domain := strings.Split(SMTPAccount, "@")[1]
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount SMTPFrom = SMTPAccount
} }
if SMTPServer == "" && SMTPAccount == "" {
return fmt.Errorf("SMTP 服务器未配置")
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Date: %s\r\n"+ "Date: %s\r\n"+
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), generateMessageID(), content)) receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
var err error var err error
if SMTPPort == 465 || SMTPSSLEnabled { if SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: SMTPServer, ServerName: SMTPServer,
@ -71,9 +62,6 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil { if err != nil {
return err return err
} }
} else if isOutlookServer(SMTPAccount) {
auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else { } else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} }

View File

@ -1,38 +0,0 @@
package common
import (
"fmt"
"os"
"strconv"
)
func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
b, err := strconv.ParseBool(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
return defaultValue
}
return b
}

View File

@ -5,37 +5,18 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
const KeyRequestBody = "key_request_body"
func GetRequestBody(c *gin.Context) ([]byte, error) {
requestBody, _ := c.Get(KeyRequestBody)
if requestBody != nil {
return requestBody.([]byte), nil
}
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, err
}
_ = c.Request.Body.Close()
c.Set(KeyRequestBody, requestBody)
return requestBody.([]byte), nil
}
func UnmarshalBodyReusable(c *gin.Context, v any) error { func UnmarshalBodyReusable(c *gin.Context, v any) error {
requestBody, err := GetRequestBody(c) requestBody, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
return err return err
} }
contentType := c.Request.Header.Get("Content-Type") err = c.Request.Body.Close()
if strings.HasPrefix(contentType, "application/json") { if err != nil {
err = json.Unmarshal(requestBody, &v) return err
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
} }
err = json.Unmarshal(requestBody, &v)
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,7 +3,6 @@ package common
import ( import (
"fmt" "fmt"
"runtime/debug" "runtime/debug"
"time"
) )
func SafeGoroutine(f func()) { func SafeGoroutine(f func()) {
@ -17,7 +16,7 @@ func SafeGoroutine(f func()) {
}() }()
} }
func SafeSendBool(ch chan bool, value bool) (closed bool) { func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() { defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed. // Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil { if recover() != nil {
@ -31,36 +30,3 @@ func SafeSendBool(ch chan bool, value bool) (closed bool) {
// If the code reaches here, then the channel was not closed. // If the code reaches here, then the channel was not closed.
return false return false
} }
func SafeSendString(ch chan string, value string) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}
// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()
// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}

View File

@ -1,8 +1,6 @@
package common package common
import ( import "encoding/json"
"encoding/json"
)
var GroupRatio = map[string]float64{ var GroupRatio = map[string]float64{
"default": 1, "default": 1,

View File

@ -1,84 +0,0 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@ -8,6 +8,7 @@ import (
"golang.org/x/image/webp" "golang.org/x/image/webp"
"image" "image"
"io" "io"
"net/http"
"strings" "strings"
) )
@ -31,7 +32,7 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
} }
func IsImageUrl(url string) (bool, error) { func IsImageUrl(url string) (bool, error) {
resp, err := ProxiedHttpHead(url, OutProxyUrl) resp, err := http.Head(url)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -47,13 +48,10 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
if !isImage { if !isImage {
return return
} }
resp, err := ProxiedHttpGet(url, OutProxyUrl) resp, err := http.Get(url)
if err != nil { if err != nil {
return return
} }
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
}
defer resp.Body.Close() defer resp.Body.Close()
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body) _, err = buffer.ReadFrom(resp.Body)
@ -66,18 +64,13 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
} }
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := ProxiedHttpGet(imageUrl, OutProxyUrl) response, err := http.Get(imageUrl)
if err != nil { if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err return image.Config{}, "", err
} }
defer response.Body.Close() defer response.Body.Close()
if response.StatusCode != 200 {
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
return image.Config{}, "", err
}
var readData []byte var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

View File

@ -98,11 +98,3 @@ func LogQuota(quota int) string {
return fmt.Sprintf("%d 点额度", quota) return fmt.Sprintf("%d 点额度", quota)
} }
} }
func LogQuotaF(quota float64) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", int64(quota))
}
}

View File

@ -3,194 +3,100 @@ package common
import ( import (
"encoding/json" "encoding/json"
"strings" "strings"
"sync" "time"
) )
// from songquanpeng/one-api // ModelRatio
const (
USD2RMB = 7.3 // 暂定 1 USD = 7.3 RMB
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// modelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing // https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here // TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens // 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens // 1 === ¥0.014 / 1k tokens
var DefaultModelRatio = map[string]float64{
var defaultModelRatio = map[string]float64{
//"midjourney": 50, //"midjourney": 50,
"gpt-4-gizmo-*": 15, "gpt-4-gizmo-*": 15,
"g-*": 15, "gpt-4": 15,
"gpt-4": 15, "gpt-4-0314": 15,
"gpt-4-0314": 15, "gpt-4-0613": 15,
"gpt-4-0613": 15, "gpt-4-32k": 30,
"gpt-4-32k": 30, "gpt-4-32k-0314": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30,
"gpt-4-32k-0613": 30, "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o": 1.25, // $0.005 / 1K tokens "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens "gpt-3.5-turbo-0301": 0.75,
"gpt-4o-2024-11-20": 1.25, // $0.01 / 1K tokens "gpt-3.5-turbo-0613": 0.75,
"o1-preview": 7.5, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"o1-preview-2024-09-12": 7.5, "gpt-3.5-turbo-16k-0613": 1.5,
"o1-mini": 0.55, // $0.0011 / 1K tokens "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"o1-mini-2024-09-12": 0.55, "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"o3-mini": 0.55, "gpt-3.5-turbo-0125": 0.25,
"o3-mini-2025-01-31": 0.55, "babbage-002": 0.2, // $0.0004 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens "davinci-002": 1, // $0.002 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "text-ada-001": 0.2,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens "text-babbage-001": 0.25,
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens "text-curie-001": 1,
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "text-davinci-002": 10,
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "text-davinci-003": 10,
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens "text-davinci-edit-001": 10,
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "code-davinci-edit-001": 10,
"gpt-3.5-turbo-0301": 0.75, "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"gpt-3.5-turbo-0613": 0.75, "tts-1": 7.5, // 1k characters -> $0.015
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "tts-1-1106": 7.5, // 1k characters -> $0.015
"gpt-3.5-turbo-16k-0613": 1.5, "tts-1-hd": 15, // 1k characters -> $0.03
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "tts-1-hd-1106": 15, // 1k characters -> $0.03
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "davinci": 10,
"gpt-3.5-turbo-0125": 0.25, "curie": 10,
"babbage-002": 0.2, // $0.0004 / 1K tokens "babbage": 10,
"davinci-002": 1, // $0.002 / 1K tokens "ada": 10,
"text-ada-001": 0.2, "text-embedding-3-small": 0.01,
"text-babbage-001": 0.25, "text-embedding-3-large": 0.065,
"text-curie-001": 1, "text-embedding-ada-002": 0.05,
"text-davinci-002": 10, "text-search-ada-doc-001": 10,
"text-davinci-003": 10, "text-moderation-stable": 0.1,
"text-davinci-edit-001": 10, "text-moderation-latest": 0.1,
"code-davinci-edit-001": 10, "dall-e-2": 8,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "dall-e-3": 16,
"tts-1": 7.5, // 1k characters -> $0.015 "claude-instant-1": 0.4, // $0.8 / 1M tokens
"tts-1-1106": 7.5, // 1k characters -> $0.015 "claude-2.0": 4, // $8 / 1M tokens
"tts-1-hd": 15, // 1k characters -> $0.03 "claude-2.1": 4, // $8 / 1M tokens
"tts-1-hd-1106": 15, // 1k characters -> $0.03 "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"davinci": 10, "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"curie": 10, "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"babbage": 10, "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ada": 10, "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"text-embedding-3-small": 0.01, "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"text-embedding-3-large": 0.065, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"text-embedding-ada-002": 0.05, "PaLM-2": 1,
"text-search-ada-doc-001": 10, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"text-moderation-stable": 0.1, "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"text-moderation-latest": 0.1, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"claude-instant-1": 0.4, // $0.8 / 1M tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"claude-2.0": 4, // $8 / 1M tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"claude-2.1": 4, // $8 / 1M tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"claude-3-7-sonnet-20250219": 1.5, "glm-4": 7.143, // ¥0.1 / 1k tokens
"claude-3-7-sonnet-20250219-thinking": 1.5, "glm-4v": 7.143, // ¥0.1 / 1k tokens
"claude-3-5-haiku-20241022": 0.4, "glm-3-turbo": 0.3572,
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens "qwen-plus": 10, // ¥0.14 / 1k tokens
"claude-3-5-sonnet-20241022": 1.5, // $3 / 1M tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"ERNIE-4.0-8K": 0.120 * RMB, "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"ERNIE-3.5-8K": 0.012 * RMB, "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"ERNIE-3.5-8K-0205": 0.024 * RMB, "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-3.5-8K-1222": 0.012 * RMB, "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"ERNIE-Bot-8K": 0.024 * RMB, "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"ERNIE-3.5-4K-0205": 0.012 * RMB, "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"ERNIE-Speed-8K": 0.004 * RMB, "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus": 1.5,
"command-r-08-2024": 0.075,
"command-r-plus-08-2024": 1.25,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
"llama-3-sonar-large-32k-chat": 1 / 1000 * USD,
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
} }
var defaultModelPrice = map[string]float64{ var DefaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-2": 0.02,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1, "gpt-4-gizmo-*": 0.1,
"g-*": 0.1,
"mj_imagine": 0.1, "mj_imagine": 0.1,
"mj_variation": 0.1, "mj_variation": 0.1,
"mj_reroll": 0.1, "mj_reroll": 0.1,
@ -206,38 +112,16 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05, "mj_describe": 0.05,
"mj_upscale": 0.05, "mj_upscale": 0.05,
"swap_face": 0.05, "swap_face": 0.05,
"mj_upload": 0.05,
} }
var ( var ModelPrice = map[string]float64{}
modelPriceMap map[string]float64 = nil var ModelRatio = map[string]float64{}
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"g-*": 2,
"gpt-4-all": 2,
"gpt-4o-all": 2,
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
if modelPriceMap == nil {
modelPriceMap = defaultModelPrice
}
return modelPriceMap
}
func ModelPrice2JSONString() string { func ModelPrice2JSONString() string {
GetModelPriceMap() if len(ModelPrice) == 0 {
jsonBytes, err := json.Marshal(modelPriceMap) ModelPrice = DefaultModelPrice
}
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil { if err != nil {
SysError("error marshalling model price: " + err.Error()) SysError("error marshalling model price: " + err.Error())
} }
@ -245,42 +129,32 @@ func ModelPrice2JSONString() string {
} }
func UpdateModelPriceByJSONString(jsonStr string) error { func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock() ModelPrice = make(map[string]float64)
defer modelPriceMapMutex.Unlock() return json.Unmarshal([]byte(jsonStr), &ModelPrice)
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
} }
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false func GetModelPrice(name string, printErr bool) float64 {
func GetModelPrice(name string, printErr bool) (float64, bool) { if len(ModelPrice) == 0 {
GetModelPriceMap() ModelPrice = DefaultModelPrice
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
price, ok := modelPriceMap[name] price, ok := ModelPrice[name]
if !ok { if !ok {
if printErr { if printErr {
SysError("model price not found: " + name) SysError("model price not found: " + name)
} }
return -1, false return -1
} }
return price, true return price
}
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelRatioMap
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
GetModelRatioMap() if len(ModelRatio) == 0 {
jsonBytes, err := json.Marshal(modelRatioMap) ModelRatio = DefaultModelRatio
}
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) SysError("error marshalling model ratio: " + err.Error())
} }
@ -288,20 +162,18 @@ func ModelRatio2JSONString() string {
} }
func UpdateModelRatioByJSONString(jsonStr string) error { func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock() ModelRatio = make(map[string]float64)
defer modelRatioMapMutex.Unlock() return json.Unmarshal([]byte(jsonStr), &ModelRatio)
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
GetModelRatioMap() if len(ModelRatio) == 0 {
ModelRatio = DefaultModelRatio
}
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
ratio, ok := modelRatioMap[name] ratio, ok := ModelRatio[name]
if !ok { if !ok {
SysError("model ratio not found: " + name) SysError("model ratio not found: " + name)
return 30 return 30
@ -309,40 +181,7 @@ func GetModelRatio(name string) float64 {
return ratio return ratio
} }
func DefaultModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(defaultModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func CompletionRatio2JSONString() string {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
}
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") { if strings.HasSuffix(name, "0125") {
return 3 return 3
@ -350,32 +189,23 @@ func GetCompletionRatio(name string) float64 {
if strings.HasSuffix(name, "1106") { if strings.HasSuffix(name, "1106") {
return 2 return 2
} }
if name == "gpt-3.5-turbo" { if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
return 3 // TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
} }
return 1.333333
return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" { if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || "gpt-4o-2024-05-13" == name { if strings.HasSuffix(name, "preview") {
return 3 return 3
} }
if strings.HasPrefix(name, "gpt-4o") {
return 4
}
return 2 return 2
} }
if "o1" == name || strings.HasPrefix(name, "o1-") {
return 4
}
if "o3" == name || strings.HasPrefix(name, "o3-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-instant-1") { if strings.HasPrefix(name, "claude-instant-1") {
return 3 return 3
} else if strings.HasPrefix(name, "claude-2") { } else if strings.HasPrefix(name, "claude-2") {
@ -383,55 +213,5 @@ func GetCompletionRatio(name string) float64 {
} else if strings.HasPrefix(name, "claude-3") { } else if strings.HasPrefix(name, "claude-3") {
return 5 return 5
} }
if strings.HasPrefix(name, "mistral-") {
return 3
}
if strings.HasPrefix(name, "gemini-") {
return 4
}
if strings.HasPrefix(name, "command") {
switch name {
case "command-r":
return 3
case "command-r-plus":
return 5
case "command-r-08-2024":
return 4
case "command-r-plus-08-2024":
return 4
default:
return 2
}
}
if strings.HasPrefix(name, "deepseek") {
return 2
}
if strings.HasPrefix(name, "ERNIE-Speed-") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Lite-") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Character") {
return 2
} else if strings.HasPrefix(name, "ERNIE-Functions") {
return 2
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.64
case "llama3-8b-8192":
return 2
case "llama3-70b-8192":
return 0.79 / 0.59
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
return 1 return 1
} }
func GetCompletionRatioMap() map[string]float64 {
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio
}

View File

@ -18,8 +18,9 @@ func InitRedisClient() (err error) {
return nil return nil
} }
if os.Getenv("SYNC_FREQUENCY") == "" { if os.Getenv("SYNC_FREQUENCY") == "" {
SysLog("SYNC_FREQUENCY not set, use default value 60") RedisEnabled = false
SyncFrequency = 60 SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
} }
SysLog("Redis is enabled") SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))

View File

@ -1,70 +1,50 @@
package common package common
import ( func SundaySearch(text string, pattern string) bool {
"encoding/json" // 计算偏移表
"math/rand" offset := make(map[rune]int)
"strconv" for i, c := range pattern {
"unsafe" offset[c] = len(pattern) - i
)
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
} }
return str
}
func GetRandomString(length int) string { // 文本串长度和模式串长度
//rand.Seed(time.Now().UnixNano()) n, m := len(text), len(pattern)
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func MapToJsonStr(m map[string]interface{}) string { // 主循环i表示当前对齐的文本串位置
bytes, err := json.Marshal(m) for i := 0; i <= n-m; {
if err != nil { // 检查子串
return "" j := 0
} for j < m && text[i+j] == pattern[j] {
return string(bytes) j++
} }
// 如果完全匹配,返回匹配位置
func StrToMap(str string) map[string]interface{} { if j == m {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return nil
}
return m
}
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true return true
} }
// 如果还有剩余字符,则检查下一位字符在偏移表中的值
if i+m < n {
next := rune(text[i+m])
if val, ok := offset[next]; ok {
i += val // 存在于偏移表中,进行跳跃
} else {
i += len(pattern) + 1 // 不存在于偏移表中,跳过整个模式串长度
}
} else {
break
}
} }
return false return false // 如果没有找到匹配,返回-1
} }
// StringToByteSlice []byte only read, panic on append func RemoveDuplicate(s []string) []string {
func StringToByteSlice(s string) []byte { result := make([]string, 0, len(s))
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) temp := map[string]struct{}{}
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} for _, item := range s {
return *(*[]byte)(unsafe.Pointer(&tmp2)) if _, ok := temp[item]; !ok {
temp[item] = struct{}{}
result = append(result, item)
}
}
return result
} }

View File

@ -1,8 +1,6 @@
package common package common
import ( import "encoding/json"
"encoding/json"
)
var TopupGroupRatio = map[string]float64{ var TopupGroupRatio = map[string]float64{
"default": 1, "default": 1,

View File

@ -1,46 +0,0 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@ -1,25 +1,19 @@
package common package common
import ( import (
"context"
"errors"
crand "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template" "html/template"
"log" "log"
"math/big"
"math/rand" "math/rand"
"net" "net"
"net/http" "os"
"net/url"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"unsafe"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@ -136,11 +130,6 @@ func IntMax(a int, b int) int {
} }
} }
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string { func GetUUID() string {
code := uuid.New().String() code := uuid.New().String()
code = strings.Replace(code, "-", "", -1) code = strings.Replace(code, "-", "", -1)
@ -150,35 +139,33 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() { func init() {
rand.New(rand.NewSource(time.Now().UnixNano())) rand.Seed(time.Now().UnixNano())
} }
func GenerateRandomCharsKey(length int) (string, error) { func GenerateKey() string {
b := make([]byte, length)
maxI := big.NewInt(int64(len(keyChars)))
for i := range b {
n, err := crand.Int(crand.Reader, maxI)
if err != nil {
return "", err
}
b[i] = keyChars[n.Int64()]
}
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
//rand.Seed(time.Now().UnixNano()) //rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48) key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
} }
func GetRandomInt(max int) int { func GetRandomInt(max int) int {
@ -203,64 +190,49 @@ func Max(a int, b int) int {
} }
} }
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string { func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id) return fmt.Sprintf("%s (request id: %s)", message, id)
} }
func RandomSleep() { func String2Int(str string) int {
// Sleep for 0-3000 ms num, err := strconv.Atoi(str)
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) if err != nil {
return 0
}
return num
} }
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) { func StringsContains(strs []string, str string) bool {
if "" == proxyUrl { for _, s := range strs {
return &http.Client{}, nil if s == str {
} return true
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
if strings.HasPrefix(proxyUrl, "http") {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(u),
},
}, nil
} else if strings.HasPrefix(proxyUrl, "socks") {
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
return nil, err
} }
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
},
}, nil
} }
return false
return nil, errors.New("unsupported proxy type")
} }
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) { // StringToByteSlice []byte only read, panic on append
client, err := GetProxiedHttpClient(proxyUrl) func StringToByteSlice(s string) []byte {
if err != nil { tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
return nil, err tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
} return *(*[]byte)(unsafe.Pointer(&tmp2))
return client.Get(url)
}
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Head(url)
} }

View File

@ -1,35 +0,0 @@
package constant
import (
"encoding/json"
"one-api/common"
)
var Chats = []map[string]string{
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},
{
"OpenCat": "opencat://team/join?domain={address}&token={key}",
},
}
func UpdateChatsByJsonString(jsonString string) error {
Chats = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &Chats)
}
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
common.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
}

View File

@ -1,51 +0,0 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@ -1,10 +1,6 @@
package constant package constant
var MjNotifyEnabled = false var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const ( const (
MjErrorUnknown = 5 MjErrorUnknown = 5
@ -27,7 +23,6 @@ const (
MjActionLowVariation = "LOW_VARIATION" MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN" MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE" MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
) )
var MidjourneyModel2Action = map[string]string{ var MidjourneyModel2Action = map[string]string{
@ -46,5 +41,4 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation, "mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan, "mj_pan": MjActionPan,
"swap_face": MjActionSwapFace, "swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
} }

View File

@ -1,8 +0,0 @@
package constant
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1

View File

@ -4,8 +4,7 @@ import "strings"
var CheckSensitiveEnabled = true var CheckSensitiveEnabled = true
var CheckSensitiveOnPromptEnabled = true var CheckSensitiveOnPromptEnabled = true
var CheckSensitiveOnCompletionEnabled = true
//var CheckSensitiveOnCompletionEnabled = true
// StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词 // StopOnSensitiveEnabled 如果检测到敏感词,是否立刻停止生成,否则替换敏感词
var StopOnSensitiveEnabled = true var StopOnSensitiveEnabled = true
@ -16,7 +15,7 @@ var StreamCacheQueueLength = 0
// SensitiveWords 敏感词 // SensitiveWords 敏感词
// var SensitiveWords []string // var SensitiveWords []string
var SensitiveWords = []string{ var SensitiveWords = []string{
"test_sensitive", "test",
} }
func SensitiveWordsToString() string { func SensitiveWordsToString() string {
@ -38,6 +37,6 @@ func ShouldCheckPromptSensitive() bool {
return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled return CheckSensitiveEnabled && CheckSensitiveOnPromptEnabled
} }
//func ShouldCheckCompletionSensitive() bool { func ShouldCheckCompletionSensitive() bool {
// return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled return CheckSensitiveEnabled && CheckSensitiveOnCompletionEnabled
//} }

View File

@ -1,18 +0,0 @@
package constant
type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
)
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
)
var SunoModel2Action = map[string]string{
"suno_music": SunoActionMusic,
"suno_lyrics": SunoActionLyrics,
}

View File

@ -5,36 +5,29 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"io" "io"
"math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney { if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil return errors.New("midjourney channel test is not supported"), nil
} }
if channel.Type == common.ChannelTypeSunoAPI { common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
return errors.New("suno channel test is not supported"), nil
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@ -43,52 +36,40 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
Body: nil, Body: nil,
Header: make(http.Header), Header: make(http.Header),
} }
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
}
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key) c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
middleware.SetupContextForSelectedChannel(c, channel, testModel) case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
meta := relaycommon.GenRelayInfo(c) meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type) apiType := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
if testModel == "" {
request := buildTestRequest(testModel) testModel = adaptor.GetModelList()[0]
meta.UpstreamModelName = testModel
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta) adaptor.Init(meta, *request)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request) convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@ -102,67 +83,43 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if err != nil { if err != nil {
return err, nil return err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(resp) err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
} }
if usage == nil { if usage == nil {
return errors.New("usage is nil"), nil return errors.New("usage is nil"), nil
} }
result := w.Result() result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
return err, nil return err, nil
} }
modelPrice, usePrice := common.GetModelPrice(testModel, false)
modelRatio := common.GetModelRatio(testModel)
completionRatio := common.GetCompletionRatio(testModel)
ratio := modelRatio
quota := 0
if !usePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit)
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil return nil, nil
} }
func buildTestRequest(model string) *dto.GeneralOpenAIRequest { func buildTestRequest() *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later Model: "", // this will be set later
Stream: false, MaxTokens: 1,
}
if "o1" == model || strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
Role: "user", Role: "user",
Content: content, Content: content,
} }
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -170,7 +127,7 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
channel, err := model.GetChannelById(channelId, true) channel, err := model.GetChannelById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -223,38 +180,33 @@ func testAllChannels(notify bool) error {
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
gopool.Go(func() { go func() {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiWithStatusErr := testChannel(channel, "") err, openaiErr := testChannel(channel, "")
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false ban := false
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
}
if milliseconds > disableThreshold { if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
shouldBanChannel = true ban = true
} }
if openaiErr != nil {
// disable channel err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { ban = true
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error()) service.DisableChannel(channel.Id, channel.Name, err.Error())
} }
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name) service.EnableChannel(channel.Id, channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
} }
@ -267,7 +219,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
} }
}) }()
return nil return nil
} }

View File

@ -1,8 +1,6 @@
package controller package controller
import ( import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -11,34 +9,6 @@ import (
"strings" "strings"
) )
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} `json:"permission"`
Root string `json:"root"`
Parent string `json:"parent"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
func GetAllChannels(c *gin.Context) { func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
@ -65,65 +35,6 @@ func GetAllChannels(c *gin.Context) {
return return
} }
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
})
}
var ids []string
for _, model := range result.Data {
ids = append(ids, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ids,
})
}
func FixChannelsAbilities(c *gin.Context) { func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility() count, err := model.FixAbility()
if err != nil { if err != nil {
@ -198,28 +109,6 @@ func AddChannel(c *gin.Context) {
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
@ -319,27 +208,6 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} }
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -112,9 +112,7 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{ user := model.User{
GitHubId: githubUser.Login, GitHubId: githubUser.Login,
} }
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) { if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId() err := user.FillUserByGitHubId()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -123,18 +121,8 @@ func GitHubOAuth(c *gin.Context) {
}) })
return return
} }
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@ -145,7 +133,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@ -4,7 +4,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
@ -18,22 +17,3 @@ func GetGroups(c *gin.Context) {
"data": groupNames, "data": groupNames,
}) })
} }
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

View File

@ -1,247 +0,0 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@ -1,19 +1,18 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"github.com/gin-gonic/gin"
) )
func GetAllLogs(c *gin.Context) { func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 { if p < 0 {
p = 1 p = 0
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@ -25,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel) logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -36,20 +35,16 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": map[string]any{ "data": logs,
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 { if p < 0 {
p = 1 p = 0
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@ -63,7 +58,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize) logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -74,12 +69,7 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": map[string]any{ "data": logs,
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return return
} }
@ -192,7 +182,7 @@ func DeleteHistoryLogs(c *gin.Context) {
}) })
return return
} }
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100) count, err := model.DeleteOldLog(targetTimestamp)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -10,11 +10,11 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings"
"time" "time"
) )
@ -86,7 +86,7 @@ func UpdateMidjourneyTaskBulk() {
continue continue
} }
// 设置超时时间 // 设置超时时间
timeout := time.Second * 15 timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求 // 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx) req = req.WithContext(ctx)
@ -146,26 +146,28 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons) buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr) task.Buttons = string(buttonStr)
} }
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
if task.Quota != 0 { err = model.CacheUpdateUserQuota(task.UserId)
shouldReturnQuota = true if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
} }
} }
err = task.Update() err = task.Update()
if err != nil { if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
} }
} }
} }
@ -231,12 +233,6 @@ func GetAllMidjourney(c *gin.Context) {
if logs == nil { if logs == nil {
logs = make([]*model.Midjourney, 0) logs = make([]*model.Midjourney, 0)
} }
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@ -263,7 +259,7 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil { if logs == nil {
logs = make([]*model.Midjourney, 0) logs = make([]*model.Midjourney, 0)
} }
if constant.MjForwardUrlEnabled { if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney

View File

@ -33,13 +33,10 @@ func GetStatus(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
"version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled, "telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName, "telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName, "system_name": common.SystemName,
@ -48,7 +45,7 @@ func GetStatus(c *gin.Context) {
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": common.ServerAddress,
"stripe_unit_price": common.StripeUnitPrice, "price": common.Price,
"min_topup": common.MinTopUp, "min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
@ -59,13 +56,11 @@ func GetStatus(c *gin.Context) {
"display_in_currency": common.DisplayInCurrencyEnabled, "display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled, "enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled, "enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"payment_enabled": common.PaymentEnabled, "enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
}, },
}) })
return return
@ -124,20 +119,10 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
parts := strings.Split(email, "@")
if len(parts) != 2 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled { if common.EmailDomainRestrictionEnabled {
allowed := false allowed := false
for _, domain := range common.EmailDomainWhitelist { for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain { if strings.HasSuffix(email, "@"+domain) {
allowed = true allowed = true
break break
} }
@ -145,22 +130,11 @@ func SendEmailVerification(c *gin.Context) {
if !allowed { if !allowed {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.", "message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
}) })
return return
} }
} }
if common.EmailAliasRestrictionEnabled {
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
if containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
})
return
}
}
if model.IsEmailAlreadyTaken(email) { if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -4,28 +4,48 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/channel/ai360" "one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
) )
// https://platform.openai.com/docs/api-reference/models/list // https://platform.openai.com/docs/api-reference/models/list
var openAIModels []dto.OpenAIModels type OpenAIModelPermission struct {
var openAIModelsMap map[string]dto.OpenAIModels Id string `json:"id"`
var channelId2Models map[int][]string Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
func getPermission() []dto.OpenAIModelPermission { type OpenAIModels struct {
var permission []dto.OpenAIModelPermission Id string `json:"id"`
permission = append(permission, dto.OpenAIModelPermission{ Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
func init() {
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission", Object: "model_permission",
Created: 1626777600, Created: 1626777600,
@ -39,12 +59,7 @@ func getPermission() []dto.OpenAIModelPermission {
Group: nil, Group: nil,
IsBlocking: false, IsBlocking: false,
}) })
return permission
}
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
permission := getPermission()
for i := 0; i < relayconstant.APITypeDummy; i++ { for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary { if i == relayconstant.APITypeAIProxyLibrary {
continue continue
@ -53,7 +68,7 @@ func init() {
channelName := adaptor.GetChannelName() channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList() modelNames := adaptor.GetModelList()
for _, modelName := range modelNames { for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@ -65,51 +80,29 @@ func init() {
} }
} }
for _, modelName := range ai360.ModelList { for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: ai360.ChannelName, OwnedBy: "360",
Permission: permission, Permission: permission,
Root: modelName, Root: modelName,
Parent: nil, Parent: nil,
}) })
} }
for _, modelName := range moonshot.ModelList { for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: moonshot.ChannelName, OwnedBy: "moonshot",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
Permission: permission, Permission: permission,
Root: modelName, Root: modelName,
Parent: nil, Parent: nil,
}) })
} }
for modelName, _ := range constant.MidjourneyModel2Action { for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@ -119,107 +112,46 @@ func init() {
Parent: nil, Parent: nil,
}) })
} }
openAIModelsMap = make(map[string]dto.OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, aiModel := range openAIModels { for _, model := range openAIModels {
openAIModelsMap[aiModel.Id] = aiModel openAIModelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType, success := relayconstant.ChannelType2APIType(i)
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
continue
}
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
} }
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
userOpenAiModels := make([]dto.OpenAIModels, 0) userId := c.GetInt("id")
permission := getPermission() user, err := model.GetUserById(userId, true)
if err != nil {
modelLimitEnable := c.GetBool("token_model_limit_enabled") c.JSON(http.StatusOK, gin.H{
if modelLimitEnable { "success": false,
s, ok := c.Get("token_model_limit") "message": err.Error(),
var tokenModelLimit map[string]bool })
if ok { return
tokenModelLimit = s.(map[string]bool) }
} else { models := model.GetGroupModels(user.Group)
tokenModelLimit = map[string]bool{} userOpenAiModels := make([]OpenAIModels, 0)
} for _, s := range models {
for allowModel, _ := range tokenModelLimit { if _, ok := openAIModelsMap[s]; ok {
if _, ok := openAIModelsMap[allowModel]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
})
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "object": "list",
"data": userOpenAiModels, "data": userOpenAiModels,
}) })
} }
func ChannelListModels(c *gin.Context) { func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "object": "list",
"data": openAIModels, "data": openAIModels,
})
}
func DashboardListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": channelId2Models,
}) })
} }
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelId := c.Param("model") modelId := c.Param("model")
if aiModel, ok := openAIModelsMap[modelId]; ok { if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, aiModel) c.JSON(200, model)
} else { } else {
openAIError := dto.OpenAIError{ openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId), Message: fmt.Sprintf("The model '%s' does not exist", modelId),

View File

@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) {
var options []*model.Option var options []*model.Option
common.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap { for k, v := range common.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue continue
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
@ -50,14 +50,6 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -1,40 +0,0 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
)
func GetPricing(c *gin.Context) {
pricing := model.GetPricing()
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": common.GroupRatio,
})
}
func ResetModelRatio(c *gin.Context) {
defaultStr := common.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = common.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "重置模型倍率成功",
})
}

View File

@ -1,223 +1,61 @@
package controller package controller
import ( import (
"bytes"
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/constant" "one-api/relay/constant"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strings" "strconv"
) )
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
var err *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case relayconstant.RelayModeImagesGenerations: case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode) err = relay.RelayImageHelper(c, relayMode)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranscription: case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c) err = relay.AudioHelper(c, relayMode)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c)
} }
return err if err != nil {
} requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry")
func Playground(c *gin.Context) { retryTimes, _ := strconv.Atoi(retryTimesStr)
var openaiErr *dto.OpenAIErrorWithStatusCode if retryTimesStr == "" {
retryTimes = common.RetryTimes
defer func() { }
if openaiErr != nil { if retryTimes > 0 {
c.JSON(openaiErr.StatusCode, gin.H{ c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
"error": openaiErr.Error, } else {
if err.StatusCode == http.StatusTooManyRequests {
//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.Error,
}) })
} }
}() channelId := c.GetInt("channel_id")
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban") autoBan := c.GetBool("auto_ban")
autoBanInt := 1 common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
if !autoBan { // https://platform.openai.com/docs/guides/error-codes/api-errors
autoBanInt = 0 if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
} }
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
if openaiErr.LocalError {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
if openaiErr.StatusCode == 307 {
return true
}
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
return false
}
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if openaiErr.StatusCode/100 == 2 {
return false
}
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message)
} }
} }
@ -250,7 +88,7 @@ func RelayMidjourney(c *gin.Context) {
"code": err.Code, "code": err.Code,
}) })
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
} }
} }
@ -277,94 +115,3 @@ func RelayNotFound(c *gin.Context) {
"error": err, "error": err,
}) })
} }
func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id")
relayMode := c.GetInt("relay_mode")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
taskErr := taskRelayHandler(c, relayMode)
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
c.JSON(taskErr.StatusCode, taskErr)
}
}
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
}
return err
}
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
if taskErr == nil {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if taskErr.StatusCode == http.StatusTooManyRequests {
return true
}
if taskErr.StatusCode == 307 {
return true
}
if taskErr.StatusCode/100 == 5 {
// 超时不重试
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
return false
}
return true
}
if taskErr.StatusCode == http.StatusBadRequest {
return false
}
if taskErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if taskErr.LocalError {
return false
}
if taskErr.StatusCode/100 == 2 {
return false
}
return true
}

View File

@ -1,99 +0,0 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v79"
"github.com/stripe/stripe-go/v79/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}

View File

@ -1,284 +0,0 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"sort"
"strconv"
"time"
)
func UpdateTaskBulk() {
//revocer
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog("任务进度轮询开始")
ctx := context.TODO()
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)
for _, t := range allTasks {
platformTask[t.Platform] = append(platformTask[t.Platform], t)
}
for platform, tasks := range platformTask {
if len(tasks) == 0 {
continue
}
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Task)
nullTaskIds := make([]int64, 0)
for _, task := range tasks {
if task.TaskID == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.ID)
continue
}
taskM[task.TaskID] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
}
if len(nullTaskIds) > 0 {
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
UpdateTaskByPlatform(platform, taskChannelM, taskM)
}
common.SysLog("任务进度轮询完成")
}
}
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
switch platform {
case constant.TaskPlatformMidjourney:
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
}
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil {
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
}
}
return nil
}
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
channel, err := model.CacheGetChannel(channelId)
if err != nil {
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
err = model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
}
return err
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
if adaptor == nil {
return errors.New("adaptor not found")
}
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
"ids": taskIds,
})
if err != nil {
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
return err
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
return err
}
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
return err
}
for _, responseItem := range responseItems.Data {
task := taskM[responseItem.TaskID]
if !checkTaskNeedUpdate(task, responseItem) {
continue
}
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
if responseItem.Status == model.TaskStatusSuccess {
task.Progress = "100%"
}
task.Data = responseItem.Data
err = task.Update()
if err != nil {
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
}
}
return nil
}
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if string(oldTask.Status) != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
return true
}
oldData, _ := json.Marshal(oldTask.Data)
newData, _ := json.Marshal(newTask.Data)
sort.Slice(oldData, func(i, j int) bool {
return oldData[i] < oldData[j]
})
sort.Slice(newData, func(i, j int) bool {
return newData[i] < newData[j]
})
if string(oldData) != string(newData) {
return true
}
return false
}
func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
queryParams := model.SyncTaskQueryParams{
Platform: constant.TaskPlatform(c.Query("platform")),
TaskID: c.Query("task_id"),
Status: c.Query("status"),
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
}
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
queryParams := model.SyncTaskQueryParams{
Platform: constant.TaskPlatform(c.Query("platform")),
TaskID: c.Query("task_id"),
Status: c.Query("status"),
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
}
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"io" "io"
"net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"sort" "sort"
@ -49,13 +48,6 @@ func TelegramBind(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId user.TelegramId = telegramId
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@ -123,19 +123,10 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: key, Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@ -143,8 +134,6 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits, ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
@ -232,8 +221,6 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@ -1,20 +1,21 @@
package controller package controller
import "C"
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v79" "github.com/samber/lo"
"github.com/stripe/stripe-go/v79/checkout/session" epay "github.com/star-horizon/go-epay"
"log" "log"
"net/url"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/service"
"strconv" "strconv"
"strings" "sync"
"time" "time"
) )
type PayRequest struct { type EpayRequest struct {
Amount int `json:"amount"` Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"` PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
@ -25,114 +26,177 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
} }
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) { func GetEpayClient() *epay.Client {
if !strings.HasPrefix(common.StripeApiSecret, "sk_") { if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
return "", fmt.Errorf("无效的Stripe API密钥") return nil
} }
withUrl, err := epay.NewClientWithUrl(&epay.Config{
stripe.Key = common.StripeApiSecret PartnerID: common.EpayId,
Key: common.EpayKey,
params := &stripe.CheckoutSessionParams{ }, common.PayAddress)
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil { if err != nil {
return "", err return nil
} }
return withUrl
return result.URL, nil
} }
func GetPayAmount(count float64) float64 { func GetAmount(count float64, user model.User) float64 {
return count * common.StripeUnitPrice // 别问为什么用float64问就是这么点钱没必要
} topupGroupRatio := common.GetTopupGroupRatio(user.Group)
if topupGroupRatio == 0 {
func GetChargedAmount(count float64, user model.User) float64 { topupGroupRatio = 1
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if topUpGroupRatio == 0 {
topUpGroupRatio = 1
} }
amount := count * common.Price * topupGroupRatio
return count * topUpGroupRatio return amount
} }
func RequestPayLink(c *gin.Context) { func RequestEpay(c *gin.Context) {
var req PayRequest var req EpayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10}) c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return return
} }
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp { if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10}) c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return return
} }
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
chargedMoney := GetChargedAmount(float64(req.Amount), *user) payMoney := GetAmount(float64(req.Amount), *user)
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4)) var payType epay.PurchaseType
referenceId := "ref_" + common.Sha1(reference) if req.PaymentMethod == "zfb" {
payType = epay.Alipay
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount)) }
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil { if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: req.Amount, Amount: req.Amount,
Money: chargedMoney, Money: payMoney,
TradeNo: referenceId, TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: "pending",
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
"message": "success", }
"data": gin.H{
"payLink": payLink, // tradeNo lock
}, var orderLocks sync.Map
}) var createLock sync.Mutex
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
} }
func RequestAmount(c *gin.Context) { func RequestAmount(c *gin.Context) {
@ -142,23 +206,12 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.Amount < common.MinTopUp { if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)}) c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
payMoney := GetPayAmount(float64(req.Amount)) payMoney := GetAmount(float64(req.Amount), *user)
chargedMoney := GetChargedAmount(float64(req.Amount), *user) c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
},
})
} }

View File

@ -7,12 +7,9 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/constant"
) )
type LoginRequest struct { type LoginRequest struct {
@ -68,8 +65,6 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username) session.Set("username", user.Username)
session.Set("role", user.Role) session.Set("role", user.Role)
session.Set("status", user.Status) session.Set("status", user.Status)
session.Set("group", user.Group)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -161,9 +156,8 @@ func Register(c *gin.Context) {
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "数据库错误,请稍后重试", "message": err.Error(),
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@ -191,48 +185,6 @@ func Register(c *gin.Context) {
}) })
return return
} }
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@ -263,8 +215,7 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) { func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
group := c.Query("group") users, err := model.SearchUsers(keyword)
users, err := model.SearchUsers(keyword, group)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -323,18 +274,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
// get rand int 28-32 user.AccessToken = common.GetUUID()
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -511,7 +451,7 @@ func UpdateUser(c *gin.Context) {
updatedUser.Password = "" // rollback to what it should be updatedUser.Password = "" // rollback to what it should be
} }
updatePassword := updatedUser.Password != "" updatePassword := updatedUser.Password != ""
if err := updatedUser.Edit(updatePassword); err != nil { if err := updatedUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@ -575,7 +515,7 @@ func UpdateSelf(c *gin.Context) {
return return
} }
func HardDeleteUser(c *gin.Context) { func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -584,7 +524,7 @@ func HardDeleteUser(c *gin.Context) {
}) })
return return
} }
originUser, err := model.GetUserByIdUnscoped(id, false) originUser, err := model.GetUserById(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -608,23 +548,9 @@ func HardDeleteUser(c *gin.Context) {
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
} }
func DeleteSelf(c *gin.Context) { func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
@ -654,7 +580,6 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -702,8 +627,8 @@ func CreateUser(c *gin.Context) {
} }
type ManageRequest struct { type ManageRequest struct {
Id int `json:"id"` Username string `json:"username"`
Action string `json:"action"` Action string `json:"action"`
} }
// ManageUser Only admin user can do this // ManageUser Only admin user can do this
@ -719,7 +644,7 @@ func ManageUser(c *gin.Context) {
return return
} }
user := model.User{ user := model.User{
Id: req.Id, Username: req.Username,
} }
// Fill attributes // Fill attributes
model.DB.Unscoped().Where(&user).First(&user) model.DB.Unscoped().Where(&user).First(&user)
@ -864,11 +789,7 @@ type topUpRequest struct {
Key string `json:"key"` Key string `json:"key"`
} }
var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {

View File

@ -78,13 +78,6 @@ func WeChatAuth(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@ -2,17 +2,18 @@ version: '3.4'
services: services:
new-api: new-api:
image: pengzhile/new-api:latest image: calciumion/new-api:latest
# build: .
container_name: new-api container_name: new-api
restart: always restart: always
command: --log-dir /app/logs command: --log-dir /app/logs
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data/new-api:/data - ./data:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
@ -22,22 +23,13 @@ services:
depends_on: depends_on:
- redis - redis
- db healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
redis: redis:
image: redis:7.4 image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
db:
image: mysql:8.2
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

View File

@ -1,34 +1,13 @@
package dto package dto
type AudioRequest struct { type TextToSpeechRequest struct {
Model string `json:"model"` Model string `json:"model" binding:"required"`
Input string `json:"input"` Input string `json:"input" binding:"required"`
Voice string `json:"voice"` Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed,omitempty"` Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format"`
} }
type AudioResponse struct { type AudioResponse struct {
Text string `json:"text"` Text string `json:"text"`
} }
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}

View File

@ -12,11 +12,9 @@ type ImageRequest struct {
} }
type ImageResponse struct { type ImageResponse struct {
Data []ImageData `json:"data"` Created int `json:"created"`
Created int64 `json:"created"` Data []struct {
} Url string `json:"url"`
type ImageData struct { B64Json string `json:"b64_json"`
Url string `json:"url"` }
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
} }

View File

@ -10,7 +10,6 @@ type OpenAIError struct {
type OpenAIErrorWithStatusCode struct { type OpenAIErrorWithStatusCode struct {
Error OpenAIError `json:"error"` Error OpenAIError `json:"error"`
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
LocalError bool
} }
type GeneralErrorResponse struct { type GeneralErrorResponse struct {

View File

@ -33,12 +33,6 @@ type MidjourneyResponse struct {
Result string `json:"result"` Result string `json:"result"`
} }
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct { type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"` StatusCode int `json:"statusCode"`
Response MidjourneyResponse Response MidjourneyResponse

View File

@ -1,6 +0,0 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

View File

@ -1,26 +0,0 @@
package dto
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}

View File

@ -1,22 +0,0 @@
package dto
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
type RerankResponseDocument struct {
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
}

View File

@ -1,129 +0,0 @@
package dto
import (
"encoding/json"
)
type TaskData interface {
SunoDataResponse | []SunoDataResponse | string | any
}
type SunoSubmitReq struct {
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
Prompt string `json:"prompt,omitempty"`
Mv string `json:"mv,omitempty"`
Title string `json:"title,omitempty"`
Tags string `json:"tags,omitempty"`
ContinueAt float64 `json:"continue_at,omitempty"`
TaskID string `json:"task_id,omitempty"`
ContinueClipId string `json:"continue_clip_id,omitempty"`
MakeInstrumental bool `json:"make_instrumental"`
}
type FetchReq struct {
IDs []string `json:"ids"`
}
type SunoDataResponse struct {
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
type SunoSong struct {
ID string `json:"id"`
VideoURL string `json:"video_url"`
AudioURL string `json:"audio_url"`
ImageURL string `json:"image_url"`
ImageLargeURL string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
ModelName string `json:"model_name"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
Metadata SunoMetadata `json:"metadata"`
}
type SunoMetadata struct {
Tags string `json:"tags"`
Prompt string `json:"prompt"`
GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
AudioPromptID interface{} `json:"audio_prompt_id"`
Duration interface{} `json:"duration"`
ErrorType interface{} `json:"error_type"`
ErrorMessage interface{} `json:"error_message"`
}
type SunoLyrics struct {
ID string `json:"id"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
}
const TaskSuccessCode = "success"
type TaskResponse[T TaskData] struct {
Code string `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
}
func (t *TaskResponse[T]) IsSuccess() bool {
return t.Code == TaskSuccessCode
}
type TaskDto struct {
TaskID string `json:"task_id"` // 第三方id不一定有/ song id\ Task id
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
Progress string `json:"progress"`
Data json.RawMessage `json:"data"`
}
type SunoGoAPISubmitReq struct {
CustomMode bool `json:"custom_mode"`
Input SunoGoAPISubmitReqInput `json:"input"`
NotifyHook string `json:"notify_hook,omitempty"`
}
type SunoGoAPISubmitReqInput struct {
GptDescriptionPrompt string `json:"gpt_description_prompt"`
Prompt string `json:"prompt"`
Mv string `json:"mv"`
Title string `json:"title"`
Tags string `json:"tags"`
ContinueAt float64 `json:"continue_at"`
TaskID string `json:"task_id"`
ContinueClipId string `json:"continue_clip_id"`
MakeInstrumental bool `json:"make_instrumental"`
}
type GoAPITaskResponse[T any] struct {
Code int `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
ErrorMessage string `json:"error_message,omitempty"`
}
type GoAPITaskResponseData struct {
TaskID string `json:"task_id"`
}
type GoAPIFetchResponseData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
Input string `json:"input"`
Clips map[string]SunoSong `json:"clips"`
}

View File

@ -1,10 +0,0 @@
package dto
type TaskError struct {
Code string `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
StatusCode int `json:"-"`
LocalError bool `json:"-"`
Error error `json:"-"`
}

View File

@ -2,66 +2,34 @@ package dto
import "encoding/json" import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` Messages []Message `json:"messages,omitempty"`
Messages []Message `json:"messages,omitempty"` Prompt any `json:"prompt,omitempty"`
Prompt any `json:"prompt,omitempty"` Stream bool `json:"stream,omitempty"`
BestOf int `json:"best_of,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Echo bool `json:"echo,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Stream bool `json:"stream,omitempty"` TopP float64 `json:"top_p,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` TopK int `json:"top_k,omitempty"`
Suffix string `json:"suffix,omitempty"` Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` N int `json:"n,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` Input any `json:"input,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Instruction string `json:"instruction,omitempty"`
TopP float64 `json:"top_p,omitempty"` Size string `json:"size,omitempty"`
TopK int `json:"top_k,omitempty"` Functions any `json:"functions,omitempty"`
Stop any `json:"stop,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
Input any `json:"input,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Instruction string `json:"instruction,omitempty"` Seed float64 `json:"seed,omitempty"`
Size string `json:"size,omitempty"` Tools any `json:"tools,omitempty"`
Functions any `json:"functions,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` User string `json:"user,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
}
type OpenAIFunction struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
} }
func (r GeneralOpenAIRequest) ParseInput() []string { func (r GeneralOpenAIRequest) ParseInput() []string {
@ -84,12 +52,11 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content json.RawMessage `json:"content"` Content json.RawMessage `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"` Name *string `json:"name,omitempty"`
Name *string `json:"name,omitempty"` ToolCalls any `json:"tool_calls,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
} }
type MediaMessage struct { type MediaMessage struct {
@ -116,11 +83,6 @@ func (m Message) StringContent() string {
return string(m.Content) return string(m.Content)
} }
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m Message) IsStringContent() bool { func (m Message) IsStringContent() bool {
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@ -160,7 +122,7 @@ func (m Message) ParseContent() []MediaMessage {
if ok { if ok {
subObj["detail"] = detail.(string) subObj["detail"] = detail.(string)
} else { } else {
subObj["detail"] = "high" subObj["detail"] = "auto"
} }
contentList = append(contentList, MediaMessage{ contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
@ -169,16 +131,7 @@ func (m Message) ParseContent() []MediaMessage {
Detail: subObj["detail"].(string), Detail: subObj["detail"].(string),
}, },
}) })
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high",
},
})
} }
} }
} }
return contentList return contentList

View File

@ -1,29 +1,9 @@
package dto package dto
type TextResponseWithError struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type SimpleResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
}
type TextResponse struct { type TextResponse struct {
Id string `json:"id"` Choices []*OpenAITextResponseChoice `json:"choices"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"` Usage `json:"usage"`
Error *OpenAIError `json:"error,omitempty"`
} }
type OpenAITextResponseChoice struct { type OpenAITextResponseChoice struct {
@ -34,7 +14,6 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct { type OpenAITextResponse struct {
Id string `json:"id"` Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
@ -42,9 +21,9 @@ type OpenAITextResponse struct {
} }
type OpenAIEmbeddingResponseItem struct { type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"` Object string `json:"object"`
Index int `json:"index"` Index int `json:"index"`
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
type OpenAIEmbeddingResponse struct { type OpenAIEmbeddingResponse struct {
@ -55,70 +34,25 @@ type OpenAIEmbeddingResponse struct {
} }
type ChatCompletionsStreamResponseChoice struct { type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"` Delta struct {
Logprobs *any `json:"logprobs"` Content string `json:"content"`
FinishReason *string `json:"finish_reason"` Role string `json:"role,omitempty"`
Index int `json:"index"` ToolCalls any `json:"tool_calls,omitempty"`
} } `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
type ChatCompletionsStreamResponseChoiceDelta struct { Index int `json:"index,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
if c.Content == nil {
return ""
}
return *c.Content
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
} }
type ChatCompletionsStreamResponseSimple struct { type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"` Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
} }
type CompletionsStreamResponse struct { type CompletionsStreamResponse struct {

58
go.mod
View File

@ -1,15 +1,11 @@
module one-api module one-api
// +heroku goVersion go1.18 // +heroku goVersion go1.18
go 1.21 go 1.18
require ( require (
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/gin-contrib/cors v1.4.0
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.6.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1 github.com/gin-contrib/static v0.0.1
@ -17,17 +13,14 @@ require (
github.com/go-playground/validator/v10 v10.19.0 github.com/go-playground/validator/v10 v10.19.0
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.6.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0 github.com/pkoukk/tiktoken-go v0.1.6
github.com/pkg/errors v0.9.1 github.com/samber/lo v1.38.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stripe/stripe-go/v79 v79.12.0 github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.31.0 golang.org/x/crypto v0.21.0
golang.org/x/image v0.18.0 golang.org/x/image v0.15.0
golang.org/x/net v0.33.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlite v1.4.3
@ -36,16 +29,11 @@ require (
require ( require (
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
@ -53,7 +41,6 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect
@ -64,24 +51,25 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.7.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/sys v0.18.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

132
go.sum
View File

@ -2,49 +2,27 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.11.2 h1:ywfwo0a/3j9HR8wsYGWsIWl2mvRsI950HyoxiBERw5A= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/bytedance/sonic v1.11.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0=
github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.6.0 h1:0Z7D/bVhE6ja07lI8CTjTonp6SB07o8bNuFyRbsBUQg= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
github.com/gin-contrib/cors v1.6.0/go.mod h1:cI+h6iOAyxKRtUtC6iF/Si1KSFvGm/gK+kshxlCi8ro= github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
@ -61,7 +39,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@ -85,12 +62,11 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@ -107,8 +83,6 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@ -118,9 +92,8 @@ github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@ -140,6 +113,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -147,28 +122,25 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -179,11 +151,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs=
github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@ -194,60 +163,56 @@ github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVM
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
@ -266,5 +231,4 @@ gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

24
main.go
View File

@ -3,14 +3,12 @@ package main
import ( import (
"embed" "embed"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
@ -22,10 +20,10 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
) )
//go:embed web/dist //go:embed web/build
var buildFS embed.FS var buildFS embed.FS
//go:embed web/dist/index.html //go:embed web/build/index.html
var indexPage []byte var indexPage []byte
func main() { func main() {
@ -42,11 +40,6 @@ func main() {
if err != nil { if err != nil {
common.FatalLog("failed to initialize database: " + err.Error()) common.FatalLog("failed to initialize database: " + err.Error())
} }
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {
@ -60,8 +53,6 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error()) common.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize constants
constant.InitEnv()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
if common.RedisEnabled { if common.RedisEnabled {
@ -98,14 +89,9 @@ func main() {
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
if common.IsMasterNode && constant.UpdateTask { common.SafeGoroutine(func() {
gopool.Go(func() { controller.UpdateMidjourneyTaskBulk()
controller.UpdateMidjourneyTaskBulk() })
})
gopool.Go(func() {
controller.UpdateTaskBulk()
})
}
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend: build-frontend:
@echo "Building frontend..." @echo "Building frontend..."
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build @cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
start-backend: start-backend:
@echo "Starting backend dev server..." @echo "Starting backend dev server..."

View File

@ -6,29 +6,15 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv"
"strings" "strings"
) )
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) { func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c) session := sessions.Default(c)
username := session.Get("username") username := session.Get("username")
role := session.Get("role") role := session.Get("role")
id := session.Get("id") id := session.Get("id")
status := session.Get("status") status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false
if username == nil { if username == nil {
// Check access token // Check access token
accessToken := c.Request.Header.Get("Authorization") accessToken := c.Request.Header.Get("Authorization")
@ -42,21 +28,11 @@ func authHelper(c *gin.Context, minRole int) {
} }
user := model.ValidateAccessToken(accessToken) user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" { if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid // Token is valid
username = user.Username username = user.Username
role = user.Role role = user.Role
id = user.Id id = user.Id
status = user.Status status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -66,36 +42,6 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if !useAccessToken {
// get header New-Api-User
apiUserIdStr := c.Request.Header.Get("New-Api-User")
if apiUserIdStr == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,请刷新页面或清空缓存后重试",
})
c.Abort()
return
}
apiUserId, err := strconv.Atoi(apiUserIdStr)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,登录信息无效,请重新登录",
})
c.Abort()
return
}
if id != apiUserId {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无权进行此操作,与登录用户不匹配,请重新登录",
})
c.Abort()
return
}
}
if status.(int) == common.UserStatusDisabled { if status.(int) == common.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -104,14 +50,6 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole { if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -120,33 +58,12 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username) c.Set("username", username)
c.Set("role", role) c.Set("role", role)
c.Set("id", id) c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next() c.Next()
} }
func TryUserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
session := sessions.Default(c)
id := session.Get("id")
if id != nil {
c.Set("id", id)
}
c.Next()
}
}
func UserAuth() func(c *gin.Context) { func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser) authHelper(c, common.RoleCommonUser)
@ -182,12 +99,6 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
} }
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.UserId)
}
}
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return
@ -201,15 +112,6 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
@ -223,11 +125,9 @@ func TokenAuth() func(c *gin.Context) {
} else { } else {
c.Set("token_model_limit_enabled", false) c.Set("token_model_limit_enabled", false)
} }
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1]) c.Set("channelId", parts[1])
} else { } else {
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return return

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -22,38 +21,9 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("channelId")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
@ -70,7 +40,72 @@ func Distribute() func(c *gin.Context) {
return return
} }
} else { } else {
shouldSelectChannel := true
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
// check token model mapping // check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled") modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable { if modelLimitEnable {
@ -93,8 +128,10 @@ func Distribute() func(c *gin.Context) {
} }
} }
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if shouldSelectChannel { if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题 // 如果错误,但是渠道不为空,说明是数据库一致性问题
@ -110,126 +147,36 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model)) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return return
} }
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
} }
} }
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next() c.Next()
} }
} }
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/suno/") {
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeSunoFetch ||
relayMode == relayconstant.RelayModeSunoFetchByID {
shouldSelectChannel = false
} else {
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
modelRequest.Model = modelName
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
}
c.Set("relay_mode", relayMode)
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}

View File

@ -1,13 +1,11 @@
package middleware package middleware
import ( import (
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
) )
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
@ -15,7 +13,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
}, },
}) })
c.Abort() c.Abort()
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) common.LogError(c.Request.Context(), message)
} }
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {

View File

@ -3,8 +3,6 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
) )
@ -29,20 +27,8 @@ func GetGroupModels(group string) []string {
return models return models
} }
func GetEnabledModels() []string { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
var models []string
// Find distinct models
DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
return models
}
func GetAllEnableAbilities() []Ability {
var abilities []Ability var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
@ -50,60 +36,9 @@ func getPriority(group string, model string, retry int) (int, error) {
trueVal = "true" trueVal = "true"
} }
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
// 处理错误
return 0, err
}
if len(priorities) == 0 {
// 如果没有查询到优先级,则返回错误
return 0, errors.New("数据库一致性被破坏")
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
// 如果重试次数大于优先级数,则使用最小的优先级
priorityToUse = priorities[len(priorities)-1]
} else {
priorityToUse = priorities[retry]
}
return priorityToUse, nil
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
return channelQuery
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil var err error = nil
channelQuery := getChannelQuery(group, model, retry) maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL { if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else { } else {
@ -117,16 +52,21 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
// Randomly choose one // Randomly choose one
weightSum := uint(0) weightSum := uint(0)
for _, ability_ := range abilities { for _, ability_ := range abilities {
weightSum += ability_.Weight + 10 weightSum += ability_.Weight
} }
// Randomly choose one if weightSum == 0 {
weight := common.GetRandomInt(int(weightSum)) // All weight is 0, randomly choose one
for _, ability_ := range abilities { channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
weight -= int(ability_.Weight) + 10 } else {
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight) // Randomly choose one
if weight <= 0 { weight := common.GetRandomInt(int(weightSum))
channel.Id = ability_.ChannelId for _, ability_ := range abilities {
break weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
} }
} }
} else { } else {
@ -153,16 +93,7 @@ func (channel *Channel) AddAbilities() error {
abilities = append(abilities, ability) abilities = append(abilities, ability)
} }
} }
if len(abilities) == 0 { return DB.Create(&abilities).Error
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Create(&chunk).Error
if err != nil {
return err
}
}
return nil
} }
func (channel *Channel) DeleteAbilities() error { func (channel *Channel) DeleteAbilities() error {
@ -210,7 +141,7 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table // Use channelIds to find channel not in abilities table
var abilityChannelIds []int var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error())) common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err return 0, err

View File

@ -25,6 +25,9 @@ var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error { func cacheSetToken(token *Token) error {
if !common.RedisEnabled {
return token.SelectUpdate()
}
jsonBytes, err := json.Marshal(token) jsonBytes, err := json.Marshal(token)
if err != nil { if err != nil {
return err return err
@ -87,7 +90,7 @@ func SyncTokenCache(frequency int) {
} }
} else { } else {
// 如果数据库中存在先检查redis // 如果数据库中存在先检查redis
_, err = common.RedisGet(fmt.Sprintf("token:%s", key)) _, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil { if err != nil {
// 如果redis中不存在则跳过 // 如果redis中不存在则跳过
continue continue
@ -165,11 +168,7 @@ func CacheUpdateUserQuota(id int) error {
if err != nil { if err != nil {
return err return err
} }
return cacheSetUserQuota(id, quota) err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
func cacheSetUserQuota(id int, quota int) error {
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
return err return err
} }
@ -205,30 +204,6 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err return userEnabled, err
} }
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
@ -290,16 +265,14 @@ func SyncChannelCache(frequency int) {
} }
} }
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} else if strings.HasPrefix(model, "g-") {
model = "g-*"
} }
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled { if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry) return GetRandomSatisfiedChannel(group, model)
} }
channelSyncLock.RLock() channelSyncLock.RLock()
defer channelSyncLock.RUnlock() defer channelSyncLock.RUnlock()
@ -307,27 +280,15 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
if len(channels) == 0 { if len(channels) == 0 {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
endIdx := len(channels)
uniquePriorities := make(map[int]bool) // choose by priority
for _, channel := range channels { firstChannel := channels[0]
uniquePriorities[int(channel.GetPriority())] = true if firstChannel.GetPriority() > 0 {
} for i := range channels {
var sortedUniquePriorities []int if channels[i].GetPriority() != firstChannel.GetPriority() {
for priority := range uniquePriorities { endIdx = i
sortedUniquePriorities = append(sortedUniquePriorities, priority) break
} }
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
if retry >= len(uniquePriorities) {
retry = len(uniquePriorities) - 1
}
targetPriority := int64(sortedUniquePriorities[retry])
// get the priority for the given retry number
var targetChannels []*Channel
for _, channel := range channels {
if channel.GetPriority() == targetPriority {
targetChannels = append(targetChannels, channel)
} }
} }
@ -335,14 +296,20 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
smoothingFactor := 10 smoothingFactor := 10
// Calculate the total weight of all channels up to endIdx // Calculate the total weight of all channels up to endIdx
totalWeight := 0 totalWeight := 0
for _, channel := range targetChannels { for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight() + smoothingFactor totalWeight += channel.GetWeight() + smoothingFactor
} }
//if totalWeight == 0 {
// // If all weights are 0, select a channel randomly
// return channels[rand.Intn(endIdx)], nil
//}
// Generate a random value in the range [0, totalWeight) // Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight) randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight // Find a channel based on its weight
for _, channel := range targetChannels { for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight() + smoothingFactor randomWeight -= channel.GetWeight() + smoothingFactor
if randomWeight < 0 { if randomWeight < 0 {
return channel, nil return channel, nil

View File

@ -1,10 +1,8 @@
package model package model
import ( import (
"encoding/json"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strings"
) )
type Channel struct { type Channel struct {
@ -12,7 +10,6 @@ type Channel struct {
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null"` Key string `json:"key" gorm:"not null"`
OpenAIOrganization *string `json:"openai_organization"` OpenAIOrganization *string `json:"openai_organization"`
TestModel *string `json:"test_model"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"` Weight *uint `json:"weight" gorm:"default:0"`
@ -27,49 +24,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(64);default:'default'"` Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` AutoBan *int `json:"auto_ban" gorm:"default:1"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
}
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error())
}
}
return otherInfo
}
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
common.SysError("failed to marshal other info: " + err.Error())
return
}
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
} }
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
@ -106,23 +62,16 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
// 构造WHERE子句 // 构造WHERE子句
var whereClause string var whereClause string
var args []interface{} var args []interface{}
if group != "" && group != "null" { if group != "" {
var groupCondition string whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
if common.UsingMySQL { args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
} }
// 执行查询 // 执行查询
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error err := baseQuery.Where(whereClause, args...).Find(&channels).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -203,13 +152,6 @@ func (channel *Channel) GetModelMapping() string {
return *channel.ModelMapping return *channel.ModelMapping
} }
func (channel *Channel) GetStatusCodeMapping() string {
if channel.StatusCodeMapping == nil {
return ""
}
return *channel.StatusCodeMapping
}
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
var err error var err error
err = DB.Create(channel).Error err = DB.Create(channel).Error
@ -261,31 +203,15 @@ func (channel *Channel) Delete() error {
return err return err
} }
func UpdateChannelStatusById(id int, status int, reason string) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil { if err != nil {
common.SysError("failed to update ability status: " + err.Error()) common.SysError("failed to update ability status: " + err.Error())
} }
channel, err := GetChannelById(id, true) err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil { if err != nil {
// find channel by id error, directly update status common.SysError("failed to update channel status: " + err.Error())
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
} else {
// find channel by id success, update status and other info
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
} }
} }
func UpdateChannelUsedQuota(id int, quota int) { func UpdateChannelUsedQuota(id int, quota int) {

View File

@ -3,12 +3,9 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
@ -27,7 +24,6 @@ type Log struct {
IsStream bool `json:"is_stream" gorm:"default:false"` IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"` ChannelId int `json:"channel" gorm:"index"`
TokenId int `json:"token_id" gorm:"default:0;index"` TokenId int `json:"token_id" gorm:"default:0;index"`
Other string `json:"other"`
} }
const ( const (
@ -39,7 +35,7 @@ const (
) )
func GetLogByKey(key string) (logs []*Log, err error) { func GetLogByKey(key string) (logs []*Log, err error) {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
return logs, err return logs, err
} }
@ -55,19 +51,18 @@ func RecordLog(userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := LOG_DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.SysError("failed to record log: " + err.Error())
} }
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
username, _ := CacheGetUsername(userId) username, _ := CacheGetUsername(userId)
otherStr := common.MapToJsonStr(other)
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: username, Username: username,
@ -83,28 +78,27 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenId: tokenId, TokenId: tokenId,
UseTime: useTimeSeconds, UseTime: useTimeSeconds,
IsStream: isStream, IsStream: isStream,
Other: otherStr,
} }
err := LOG_DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
if common.DataExportEnabled { if common.DataExportEnabled {
gopool.Go(func() { common.SafeGoroutine(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
}) })
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = LOG_DB tx = DB
} else { } else {
tx = LOG_DB.Where("type = ?", logType) tx = DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@ -121,26 +115,19 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil { return logs, err
return nil, 0, err
}
return logs, total, err
} }
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = LOG_DB.Where("user_id = ?", userId) tx = DB.Where("user_id = ?", userId)
} else { } else {
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) tx = DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
@ -151,30 +138,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
for i := range logs { return logs, err
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")
}
logs[i].Other = common.MapToJsonStr(otherMap)
}
return logs, total, err
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }
@ -185,18 +159,12 @@ type Stat struct {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota") tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
// 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
} }
if startTimestamp != 0 { if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp) tx = tx.Where("created_at >= ?", startTimestamp)
@ -205,29 +173,17 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
} }
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// 只统计最近60秒的rpm和tpm
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat return stat
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@ -247,25 +203,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
return token return token
} }
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) { func DeleteOldLog(targetTimestamp int64) (int64, error) {
var total int64 = 0 result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
for {
if nil != ctx.Err() {
return total, ctx.Err()
}
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
if nil != result.Error {
return total, result.Error
}
total += result.RowsAffected
if result.RowsAffected < int64(limit) {
break
}
}
return total, nil
} }

View File

@ -15,8 +15,6 @@ import (
var DB *gorm.DB var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error { func createRootAccountIfNeed() error {
var user User var user User
//if user.Status != common.UserStatusEnabled { //if user.Status != common.UserStatusEnabled {
@ -32,7 +30,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: nil, AccessToken: common.GetUUID(),
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@ -40,9 +38,9 @@ func createRootAccountIfNeed() error {
return nil return nil
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv(envName) if os.Getenv("SQL_DSN") != "" {
if dsn != "" { dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
@ -54,13 +52,6 @@ func chooseDB(envName string) (*gorm.DB, error) {
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL // Use MySQL
common.SysLog("using MySQL as database") common.SysLog("using MySQL as database")
// check parseTime // check parseTime
@ -85,7 +76,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
} }
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB("SQL_DSN") db, err := chooseDB()
if err == nil { if err == nil {
if common.DebugEnabled { if common.DebugEnabled {
db = db.Debug() db = db.Debug()
@ -95,58 +86,62 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !common.IsMasterNode {
return nil return nil
} }
//if common.UsingMySQL { if common.UsingMySQL {
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//}
common.SysLog("database migration started")
err = migrateDB()
return err
} else {
common.FatalLog(err)
}
return err
}
func InitLogDB() (err error) {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
} }
LOG_DB = db common.SysLog("database migration started")
sqlDB, err := LOG_DB.DB() err = db.AutoMigrate(&Channel{})
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) err = db.AutoMigrate(&Token{})
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) if err != nil {
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) return err
if !common.IsMasterNode {
return nil
} }
//if common.UsingMySQL { err = db.AutoMigrate(&User{})
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded if err != nil {
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded return err
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded }
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded err = db.AutoMigrate(&Option{})
//} if err != nil {
common.SysLog("database migration started") return err
err = migrateLOGDB() }
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err return err
} else { } else {
common.FatalLog(err) common.FatalLog(err)
@ -154,66 +149,8 @@ func InitLogDB() (err error) {
return err return err
} }
func migrateDB() error { func CloseDB() error {
err := DB.AutoMigrate(&Channel{}) sqlDB, err := DB.DB()
if err != nil {
return err
}
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
@ -221,16 +158,6 @@ func closeDB(db *gorm.DB) error {
return err return err
} }
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}
var ( var (
lastPingTime time.Time lastPingTime time.Time
pingMutex sync.Mutex pingMutex sync.Mutex

View File

@ -31,30 +31,25 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = "" common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = "" common.OptionMap["SMTPFrom"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
common.OptionMap["SMTPAccount"] = "" common.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPToken"] = "" common.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
common.OptionMap["Notice"] = "" common.OptionMap["Notice"] = ""
common.OptionMap["About"] = "" common.OptionMap["About"] = ""
common.OptionMap["HomePageContent"] = "" common.OptionMap["HomePageContent"] = ""
@ -62,20 +57,15 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["OutProxyUrl"] = "" common.OptionMap["PayAddress"] = ""
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret common.OptionMap["EpayId"] = ""
common.OptionMap["StripePriceId"] = common.StripePriceId common.OptionMap["EpayKey"] = ""
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled) common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp) common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = "" common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerAddress"] = ""
@ -91,8 +81,6 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2 common.OptionMap["ChatLink2"] = common.ChatLink2
@ -102,13 +90,9 @@ func InitOptionMap() {
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength) common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
@ -179,8 +163,6 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled": case "TelegramOAuthEnabled":
@ -189,12 +171,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled": case "AutomaticEnableChannelEnabled":
@ -207,32 +185,20 @@ func updateOptionMap(key string, value string) (err error) {
common.DisplayTokenStatEnabled = boolValue common.DisplayTokenStatEnabled = boolValue
case "DrawingEnabled": case "DrawingEnabled":
common.DrawingEnabled = boolValue common.DrawingEnabled = boolValue
case "TaskEnabled":
common.TaskEnabled = boolValue
case "DataExportEnabled": case "DataExportEnabled":
common.DataExportEnabled = boolValue common.DataExportEnabled = boolValue
case "DefaultCollapseSidebar": case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled": case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue constant.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled": case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled": case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue constant.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled": case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled": case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue constant.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
} }
} }
switch key { switch key {
@ -251,20 +217,16 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
common.ServerAddress = value common.ServerAddress = value
case "OutProxyUrl": case "PayAddress":
common.OutProxyUrl = value common.PayAddress = value
case "Chats": case "CustomCallbackAddress":
err = constant.UpdateChatsByJsonString(value) common.CustomCallbackAddress = value
case "StripeApiSecret": case "EpayId":
common.StripeApiSecret = value common.EpayId = value
case "StripeWebhookSecret": case "EpayKey":
common.StripeWebhookSecret = value common.EpayKey = value
case "StripePriceId": case "Price":
common.StripePriceId = value common.Price, _ = strconv.ParseFloat(value, 64)
case "PaymentEnabled":
common.PaymentEnabled, _ = strconv.ParseBool(value)
case "StripeUnitPrice":
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
common.MinTopUp, _ = strconv.Atoi(value) common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
@ -273,12 +235,6 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value common.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer": case "Footer":
common.Footer = value common.Footer = value
case "SystemName": case "SystemName":
@ -319,10 +275,6 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice": case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value) err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink": case "TopUpLink":

View File

@ -1,79 +0,0 @@
package model
import (
"one-api/common"
"sync"
"time"
)
type Pricing struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"`
}
var (
pricingMap []Pricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
pricing.ModelRatio = common.GetModelRatio(model)
pricing.CompletionRatio = common.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
}
lastGetPricingTime = time.Now()
}

View File

@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
keyCol = `"key"` keyCol = `"key"`
} }
common.RandomSleep()
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil { if err != nil {
@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@ -1,304 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"one-api/constant"
commonRelay "one-api/relay/common"
"time"
)
type TaskStatus string
const (
TaskStatusNotStart TaskStatus = "NOT_START"
TaskStatusSubmitted = "SUBMITTED"
TaskStatusQueued = "QUEUED"
TaskStatusInProgress = "IN_PROGRESS"
TaskStatusFailure = "FAILURE"
TaskStatusSuccess = "SUCCESS"
TaskStatusUnknown = "UNKNOWN"
)
type Task struct {
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
CreatedAt int64 `json:"created_at" gorm:"index"`
UpdatedAt int64 `json:"updated_at"`
TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id不一定有/ song id\ Task id
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
UserId int `json:"user_id" gorm:"index"`
ChannelId int `json:"channel_id" gorm:"index"`
Quota int `json:"quota"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Progress string `json:"progress" gorm:"type:varchar(20);index"`
Properties Properties `json:"properties" gorm:"type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
func (t *Task) SetData(data any) {
b, _ := json.Marshal(data)
t.Data = json.RawMessage(b)
}
func (t *Task) GetData(v any) error {
err := json.Unmarshal(t.Data, &v)
return err
}
type Properties struct {
Input string `json:"input"`
}
func (m *Properties) Scan(val interface{}) error {
bytesValue, _ := val.([]byte)
return json.Unmarshal(bytesValue, m)
}
func (m Properties) Value() (driver.Value, error) {
return json.Marshal(m)
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type SyncTaskQueryParams struct {
Platform constant.TaskPlatform
ChannelID string
TaskID string
UserID string
Action string
Status string
StartTimestamp int64
EndTimestamp int64
UserIDs []int
}
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
t := &Task{
UserId: relayInfo.UserId,
SubmitTime: time.Now().Unix(),
Status: TaskStatusNotStart,
Progress: "0%",
ChannelId: relayInfo.ChannelId,
Platform: platform,
}
return t
}
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB.Where("user_id = ?", userId)
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.StartTimestamp != 0 {
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetAllUnFinishSyncTasks(limit int) []*Task {
var tasks []*Task
var err error
// get all tasks progress is not 100%
err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("task_id = ?", taskId).First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
if len(taskIds) == 0 {
return nil, nil
}
var task []*Task
var err error
err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
Find(&task).Error
if err != nil {
return nil, err
}
return task, nil
}
func TaskUpdateProgress(id int64, progress string) error {
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
}
func (Task *Task) Insert() error {
var err error
err = DB.Create(Task).Error
return err
}
func (Task *Task) Update() error {
var err error
err = DB.Save(Task).Error
return err
}
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
if len(TaskIds) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("task_id in (?)", TaskIds).
Updates(params).Error
}
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
if len(taskIDs) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
if len(ids) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", ids).
Updates(params).Error
}
type TaskQuotaUsage struct {
Mode string `json:"mode"`
Count float64 `json:"count"`
}
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
query := DB.Model(Task{})
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}

View File

@ -5,14 +5,13 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
relaycommon "one-api/relay/common"
"strconv" "strconv"
"strings" "strings"
) )
type Token struct { type Token struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"` Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" ` Name string `json:"name" gorm:"index" `
@ -23,34 +22,10 @@ type Token struct {
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"` ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
var tokens []*Token var tokens []*Token
var err error var err error
@ -75,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status == common.TokenStatusExhausted { if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired { } else if token.Status == common.TokenStatusExpired {
return token, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
} }
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return token, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
@ -90,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
} }
return token, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
} }
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
@ -103,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
} }
return token, nil return token, nil
} }
@ -127,11 +102,6 @@ func GetTokenById(id int) (*Token, error) {
token := Token{Id: id} token := Token{Id: id}
var err error = nil var err error = nil
err = DB.First(&token, "id = ?", id).Error err = DB.First(&token, "id = ?", id).Error
if err != nil {
if common.RedisEnabled {
go cacheSetToken(&token)
}
}
return &token, err return &token, err
} }
@ -154,8 +124,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err return err
} }
@ -257,52 +226,51 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err return err
} }
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) { func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
if quota < 0 { if quota < 0 {
return 0, errors.New("quota 不能为负数!") return 0, errors.New("quota 不能为负数!")
} }
if !relayInfo.IsPlayground { token, err := GetTokenById(tokenId)
token, err := GetTokenById(relayInfo.TokenId) if err != nil {
if err != nil { return 0, err
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
} }
userQuota, err = GetUserQuota(relayInfo.UserId) if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if userQuota < quota { if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
if !relayInfo.IsPlayground { if !token.UnlimitedQuota {
err = DecreaseTokenQuota(relayInfo.TokenId, quota) err = DecreaseTokenQuota(tokenId, quota)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
err = DecreaseUserQuota(relayInfo.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
return userQuota - quota, err return userQuota - quota, err
} }
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
} else { } else {
err = IncreaseUserQuota(relayInfo.UserId, -quota) err = IncreaseUserQuota(token.UserId, -quota)
} }
if err != nil { if err != nil {
return err return err
} }
if !relayInfo.IsPlayground { if !token.UnlimitedQuota {
if quota > 0 { if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, quota) err = DecreaseTokenQuota(tokenId, quota)
} else { } else {
err = IncreaseTokenQuota(relayInfo.TokenId, -quota) err = IncreaseTokenQuota(tokenId, -quota)
} }
if err != nil { if err != nil {
return err return err
@ -315,7 +283,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(relayInfo.UserId) email, err := GetUserEmail(token.UserId)
if err != nil { if err != nil {
common.SysError("failed to fetch user email: " + err.Error()) common.SysError("failed to fetch user email: " + err.Error())
} }

View File

@ -1,21 +1,13 @@
package model package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"` Amount int `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique"` TradeNo string `json:"trade_no"`
CreateTime int64 `json:"create_time"` CreateTime int64 `json:"create_time"`
CompleteTime int64 `json:"complete_time"` Status string `json:"status"`
Status string `json:"status"`
} }
func (topUp *TopUp) Insert() error { func (topUp *TopUp) Insert() error {
@ -49,51 +41,3 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
} }
return topUp return topUp
} }
func Recharge(referenceId string, customerId string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota float64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
quota = topUp.Money * common.QuotaPerUnit
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.LogQuotaF(quota), topUp.Amount))
return nil
}

View File

@ -45,7 +45,6 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int,
if ok { if ok {
quotaData.Count += 1 quotaData.Count += 1
quotaData.Quota += quota quotaData.Quota += quota
quotaData.TokenUsed += tokenUsed
} else { } else {
quotaData = &QuotaData{ quotaData = &QuotaData{
UserID: userId, UserID: userId,

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"strconv"
"strings" "strings"
"time" "time"
@ -22,12 +21,10 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@ -37,21 +34,9 @@ type User struct {
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度 AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度 AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (user *User) GetAccessToken() string {
if user.AccessToken == nil {
return ""
}
return *user.AccessToken
}
func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) { func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User var user User
@ -78,7 +63,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int { func GetMaxUserId() int {
var user User var user User
DB.Unscoped().Last(&user) DB.Last(&user)
return user.Id return user.Id
} }
@ -87,40 +72,8 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
return users, err return users, err
} }
func SearchUsers(keyword string, group string) ([]*User, error) { func SearchUsers(keyword string) (users []*User, err error) {
var users []*User err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
var err error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功按照ID和可选的组别搜索用户
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
if group != "" {
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
}
err = query.Find(&users).Error
if err != nil || len(users) > 0 {
return users, err
}
}
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
err = query.Find(&users).Error
return users, err return users, err
} }
@ -138,20 +91,6 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err return &user, err
} }
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
if id == 0 {
return nil, errors.New("id 为空!")
}
user := User{Id: id}
var err error = nil
if selectAll {
err = DB.Unscoped().First(&user, "id = ?", id).Error
} else {
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
}
return &user, err
}
func GetUserIdByAffCode(affCode string) (int, error) { func GetUserIdByAffCode(affCode string) (int, error) {
if affCode == "" { if affCode == "" {
return 0, errors.New("affCode 为空!") return 0, errors.New("affCode 为空!")
@ -234,7 +173,7 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = common.QuotaForNewUser
//user.SetAccessToken(common.GetUUID()) user.AccessToken = common.GetUUID()
user.AffCode = common.GetRandomString(4) user.AffCode = common.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
@ -271,36 +210,6 @@ func (user *User) Update(updatePassword bool) error {
if err == nil { if err == nil {
if common.RedisEnabled { if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
}
return err
}
func (user *User) Edit(updatePassword bool) error {
var err error
if updatePassword {
user.Password, err = common.Password2Hash(user.Password)
if err != nil {
return err
}
}
newUser := *user
updates := map[string]interface{}{
"username": newUser.Username,
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
}
if updatePassword {
updates["password"] = newUser.Password
}
DB.First(&user, user.Id)
err = DB.Model(user).Updates(updates).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
} }
} }
return err return err
@ -328,12 +237,10 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values, // that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions // it wont be used to build query conditions
password := user.Password password := user.Password
username := strings.TrimSpace(user.Username) if user.Username == "" || password == "" {
if username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
// find buy username or email DB.Where(User{Username: user.Username}).First(user)
DB.Where("username = ? OR email = ?", username, username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@ -365,14 +272,6 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@ -381,6 +280,14 @@ func (user *User) FillUserByWeChatId() error {
return nil return nil
} }
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
}
DB.Where(User{Username: user.Username}).First(user)
return nil
}
func (user *User) FillUserByTelegramId() error { func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" { if user.TelegramId == "" {
return errors.New("Telegram id 为空!") return errors.New("Telegram id 为空!")
@ -393,27 +300,23 @@ func (user *User) FillUserByTelegramId() error {
} }
func IsEmailAlreadyTaken(email string) bool { func IsEmailAlreadyTaken(email string) bool {
return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1 return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
} }
func IsWeChatIdAlreadyTaken(wechatId string) bool { func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
} }
func IsGitHubIdAlreadyTaken(githubId string) bool { func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Unscoped().Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
} }
func IsUsernameAlreadyTaken(username string) bool { func IsUsernameAlreadyTaken(username string) bool {
return DB.Unscoped().Where("username = ?", username).Find(&User{}).RowsAffected == 1 return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }
func IsTelegramIdAlreadyTaken(telegramId string) bool { func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
} }
func ResetUserPasswordByEmail(email string, password string) error { func ResetUserPasswordByEmail(email string, password string) error {
@ -453,18 +356,6 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled, nil
} }
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
return nil return nil
@ -479,11 +370,6 @@ func ValidateAccessToken(token string) (user *User) {
func GetUserQuota(id int) (quota int, err error) { func GetUserQuota(id int) (quota int, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
if err != nil {
if common.RedisEnabled {
go cacheSetUserQuota(id, quota)
}
}
return quota, err return quota, err
} }

View File

@ -1,9 +1,6 @@
package model package model
import ( import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"sync" "sync"
"time" "time"
@ -29,12 +26,12 @@ func init() {
} }
func InitBatchUpdater() { func InitBatchUpdater() {
gopool.Go(func() { go func() {
for { for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate() batchUpdate()
} }
}) }()
} }
func addNewRecord(type_ int, id int, value int) { func addNewRecord(type_ int, id int, value int) {
@ -78,13 +75,3 @@ func batchUpdate() {
} }
common.SysLog("batch update finished") common.SysLog("batch update finished")
} }
func RecordExist(err error) (bool, error) {
if err == nil {
return true, nil
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}

View File

@ -10,34 +10,12 @@ import (
type Adaptor interface { type Adaptor interface {
// Init IsStream bool // Init IsStream bool
Init(info *relaycommon.RelayInfo) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error) GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
} }
type TaskAdaptor interface {
Init(info *relaycommon.TaskRelayInfo)
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
GetModelList() []string
GetChannelName() string
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
}

View File

@ -1,13 +1,8 @@
package ai360 package ai360
var ModelList = []string{ var ModelList = []string{
"360gpt-turbo",
"360gpt-turbo-responsibility-8k",
"360gpt-pro",
"360GPT_S2_V9", "360GPT_S2_V9",
"embedding-bert-512-v1", "embedding-bert-512-v1",
"embedding_s1_v1", "embedding_s1_v1",
"semantic_similarity_s1_v1", "semantic_similarity_s1_v1",
} }
var ChannelName = "ai360"

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
) )
@ -16,18 +15,14 @@ import (
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
switch info.RelayMode { if info.RelayMode == constant.RelayModeEmbeddings {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
} }
return fullRequestURL, nil return fullRequestURL, nil
} }
@ -44,49 +39,33 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch relayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
default: default:
aliReq := requestOpenAI2Ali(*request) baiduRequest := requestOpenAI2Ali(*request)
return aliReq, nil return baiduRequest, nil
} }
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request)
return aliRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
switch info.RelayMode { if info.IsStream {
case constant.RelayModeImagesGenerations: err, usage = aliStreamHandler(c, resp)
err, usage = aliImageHandler(c, resp, info) } else {
case constant.RelayModeEmbeddings: switch info.RelayMode {
err, usage = aliEmbeddingHandler(c, resp) case constant.RelayModeEmbeddings:
default: err, usage = aliEmbeddingHandler(c, resp)
if info.IsStream { default:
err, usage = openai.OaiStreamHandler(c, resp, info) err, usage = aliHandler(c, resp)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
} }
return return

View File

@ -36,8 +36,8 @@ type AliEmbeddingRequest struct {
} }
type AliEmbedding struct { type AliEmbedding struct {
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"` TextIndex int `json:"text_index"`
} }
type AliEmbeddingResponse struct { type AliEmbeddingResponse struct {
@ -60,40 +60,13 @@ type AliUsage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
type TaskResult struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
type AliOutput struct { type AliOutput struct {
TaskId string `json:"task_id,omitempty"` Text string `json:"text"`
TaskStatus string `json:"task_status,omitempty"` FinishReason string `json:"finish_reason"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
} }
type AliResponse struct { type AliChatResponse struct {
Output AliOutput `json:"output"` Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"` Usage AliUsage `json:"usage"`
AliError AliError
} }
type AliImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}

View File

@ -1,177 +0,0 @@
package ali
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
step++
rsp, err, body := updateTask(info, taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := common.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}

View File

@ -16,13 +16,34 @@ import (
const EnableSearchModelSuffix = "-internet" const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
if request.TopP >= 1 { messages := make([]AliMessage, 0, len(request.Messages))
request.TopP = 0.999 //prompt := ""
} else if request.TopP <= 0 { for i := 0; i < len(request.Messages); i++ {
request.TopP = 0.001 message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
} }
return &request
} }
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
@ -89,7 +110,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse { func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text) content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
@ -113,9 +134,9 @@ func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
return &fullTextResponse return &fullTextResponse
} }
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse { func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text) choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" { if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
@ -133,7 +154,18 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage var usage dto.Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {
@ -155,7 +187,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
var aliResponse AliResponse var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse) err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
@ -167,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
} }
response := streamResponseAli2OpenAI(&aliResponse) response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText)) response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
@ -189,7 +221,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
} }
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliResponse var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@ -7,19 +7,14 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/relay/common" "one-api/relay/common"
"one-api/relay/constant"
"one-api/service" "one-api/service"
) )
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
// multipart/form-data req.Header.Set("Accept", c.Request.Header.Get("Accept"))
} else { if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
} }
} }
@ -43,29 +38,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return resp, nil return resp, nil
} }
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
@ -78,27 +50,3 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
_ = c.Request.Body.Close() _ = c.Request.Body.Close()
return resp, nil return resp, nil
} }
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.BuildRequestURL(info)
if err != nil {
return nil, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
err = a.BuildRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}

View File

@ -1,85 +0,0 @@
package aws
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
var claudeReq *claude.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@ -1,14 +0,0 @@
package aws
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
}
var ChannelName = "aws"

View File

@ -1,19 +0,0 @@
package aws
import (
"one-api/relay/channel/claude"
)
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@ -1,232 +0,0 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"io"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 3 {
return nil, errors.New("invalid aws secret key")
}
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client := bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
return &relaymodel.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return requestModel, nil
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(claude.ClaudeResponse)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
if response == nil {
return true
}
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
if resp != nil {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
}

View File

@ -2,7 +2,6 @@ package baidu
import ( import (
"errors" "errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
@ -10,85 +9,33 @@ import (
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"strings"
) )
type Adaptor struct { type Adaptor struct {
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t var fullRequestURL string
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
switch info.UpstreamModelName { switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4": case "ERNIE-Bot-4":
suffix += "completions_pro" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K": case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-3.5-4K-0205": case "ERNIE-Bot":
suffix += "ernie-3.5-4k-0205" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed-8K": case "ERNIE-Speed":
suffix += "ernie_speed" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Speed-128K": case "ERNIE-Bot-turbo":
suffix += "ernie-speed-128k" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B": case "BLOOMZ-7B":
suffix += "bloomz_7b1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1": case "Embedding-V1":
suffix += "embedding-v1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += strings.ToLower(info.UpstreamModelName)
} }
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string var accessToken string
var err error var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@ -104,11 +51,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
switch info.RelayMode { switch relayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil return baiduEmbeddingRequest, nil
@ -118,15 +65,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
} }
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = baiduStreamHandler(c, resp) err, usage = baiduStreamHandler(c, resp)
} else { } else {

View File

@ -1,22 +1,12 @@
package baidu package baidu
var ModelList = []string{ var ModelList = []string{
"ERNIE-4.0-8K", "ERNIE-Bot-4",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K", "ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205", "ERNIE-Bot",
"ERNIE-Speed-8K", "ERNIE-Speed",
"ERNIE-Speed-128K", "ERNIE-Bot-turbo",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1", "Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
} }
var ChannelName = "baidu" var ChannelName = "baidu"

View File

@ -11,16 +11,9 @@ type BaiduMessage struct {
} }
type BaiduChatRequest struct { type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"` Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"` Stream bool `json:"stream"`
TopP float64 `json:"top_p,omitempty"` UserId string `json:"user_id,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
} }
type Error struct { type Error struct {
@ -50,9 +43,9 @@ type BaiduEmbeddingRequest struct {
} }
type BaiduEmbeddingData struct { type BaiduEmbeddingData struct {
Object string `json:"object"` Object string `json:"object"`
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
Index int `json:"index"` Index int `json:"index"`
} }
type BaiduEmbeddingResponse struct { type BaiduEmbeddingResponse struct {

View File

@ -22,33 +22,17 @@ import (
var baiduTokenStore sync.Map var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{ messages := make([]BaiduMessage, 0, len(request.Messages))
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { messages = append(messages, BaiduMessage{
baiduRequest.System = message.StringContent() Role: message.Role,
} else { Content: message.StringContent(),
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{ })
Role: message.Role, }
Content: message.StringContent(), return &BaiduChatRequest{
}) Messages: messages,
} Stream: request.Stream,
} }
return &baiduRequest
} }
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse { func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
@ -73,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse { func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(baiduResponse.Result) choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd { if baiduResponse.IsEnd {
choice.FinishReason = &relaycommon.StopFinishReason choice.FinishReason = &relaycommon.StopFinishReason
} }

View File

@ -21,17 +21,7 @@ type Adaptor struct {
RequestMode int RequestMode int
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") { if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage a.RequestMode = RequestModeMessage
} else { } else {
@ -50,45 +40,34 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req) channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("x-api-key", info.ApiKey) req.Header.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version") anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" { if anthropicVersion == "" {
anthropicVersion = "2023-06-01" anthropicVersion = "2023-06-01"
} }
req.Header.Set("anthropic-version", anthropicVersion) req.Header.Set("anthropic-version", anthropicVersion)
anthropicBeta := c.Request.Header.Get("anthropic-beta")
if "" != anthropicBeta {
req.Header.Set("anthropic-beta", anthropicBeta)
}
return nil return nil
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
if a.RequestMode == RequestModeCompletion { if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(c, *request), nil return requestOpenAI2ClaudeComplete(*request), nil
} else { } else {
return RequestOpenAI2ClaudeMessage(c, *request) return requestOpenAI2ClaudeMessage(*request)
} }
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else { } else {
err, usage = ClaudeHandler(c, resp, a.RequestMode, info) err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

Some files were not shown because too many files have changed in this diff Show More