mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-18 11:33:42 +08:00
Compare commits
35 Commits
v0.2.7.1-a
...
v0.2.6-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
811019bf5c | ||
|
|
0ed6600437 | ||
|
|
1fe7f14d57 | ||
|
|
a7bafec1bf | ||
|
|
ed951b3974 | ||
|
|
c74e43b8fd | ||
|
|
03bd9b0cc4 | ||
|
|
149902bd8a | ||
|
|
cd3ed22045 | ||
|
|
2329d387ca | ||
|
|
7d18a8e2a9 | ||
|
|
80af3718d0 | ||
|
|
77ea6bec46 | ||
|
|
c0ab8ae953 | ||
|
|
923c2dee32 | ||
|
|
ea17a46d8e | ||
|
|
bfe9e5d25a | ||
|
|
831ff47254 | ||
|
|
11eaba6b5d | ||
|
|
c2e4ec25c8 | ||
|
|
8537f10412 | ||
|
|
71d60eeef7 | ||
|
|
247ae0988f | ||
|
|
0907fa6994 | ||
|
|
9855343aa8 | ||
|
|
bd50fde268 | ||
|
|
4267de5642 | ||
|
|
8b55116563 | ||
|
|
f35e63e3f3 | ||
|
|
17c409de23 | ||
|
|
e4753e7411 | ||
|
|
9adefa80b9 | ||
|
|
4ce2381182 | ||
|
|
62afc21ea5 | ||
|
|
7ddb7c586d |
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,5 @@
|
|||||||
blank_issues_enabled: false
|
blank_issues_enabled: false
|
||||||
contact_links:
|
contact_links:
|
||||||
- name: 项目群聊
|
- name: 交流社区
|
||||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
url: https://linux.do
|
||||||
about: QQ 群:629454374
|
about: 项目交流社区
|
||||||
|
|||||||
5
.github/workflows/docker-image-amd64.yml
vendored
5
.github/workflows/docker-image-amd64.yml
vendored
@@ -1,9 +1,6 @@
|
|||||||
name: Publish Docker image (amd64)
|
name: Publish Docker image (amd64)
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- '*'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
name:
|
name:
|
||||||
@@ -42,7 +39,7 @@ jobs:
|
|||||||
uses: docker/metadata-action@v4
|
uses: docker/metadata-action@v4
|
||||||
with:
|
with:
|
||||||
images: |
|
images: |
|
||||||
calciumion/new-api
|
pengzhile/new-api
|
||||||
ghcr.io/${{ github.repository }}
|
ghcr.io/${{ github.repository }}
|
||||||
|
|
||||||
- name: Build and push Docker images
|
- name: Build and push Docker images
|
||||||
|
|||||||
2
.github/workflows/docker-image-arm64.yml
vendored
2
.github/workflows/docker-image-arm64.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
|||||||
uses: docker/metadata-action@v4
|
uses: docker/metadata-action@v4
|
||||||
with:
|
with:
|
||||||
images: |
|
images: |
|
||||||
calciumion/new-api
|
pengzhile/new-api
|
||||||
ghcr.io/${{ github.repository }}
|
ghcr.io/${{ github.repository }}
|
||||||
|
|
||||||
- name: Build and push Docker images
|
- name: Build and push Docker images
|
||||||
|
|||||||
214
LICENSE
214
LICENSE
@@ -1,201 +1,21 @@
|
|||||||
Apache License
|
MIT License
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
Copyright (c) 2024 Calcium-Ion
|
||||||
|
|
||||||
1. Definitions.
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
The above copyright notice and this permission notice shall be included in all
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
the copyright owner that is granting the License.
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
other entities that control, are controlled by, or are under common
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
control with that entity. For the purposes of this definition,
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
SOFTWARE.
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
|||||||
|
|
||||||
build-frontend:
|
build-frontend:
|
||||||
@echo "Building frontend..."
|
@echo "Building frontend..."
|
||||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||||
|
|
||||||
start-backend:
|
start-backend:
|
||||||
@echo "Starting backend dev server..."
|
@echo "Starting backend dev server..."
|
||||||
@@ -2,21 +2,6 @@
|
|||||||
|
|
||||||
**简介**:Midjourney Proxy API文档
|
**简介**:Midjourney Proxy API文档
|
||||||
|
|
||||||
## 接口列表
|
|
||||||
支持的接口如下:
|
|
||||||
+ [x] /mj/submit/imagine
|
|
||||||
+ [x] /mj/submit/change
|
|
||||||
+ [x] /mj/submit/blend
|
|
||||||
+ [x] /mj/submit/describe
|
|
||||||
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
|
||||||
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
|
||||||
+ [x] /task/list-by-condition
|
|
||||||
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
|
||||||
+ [x] /mj/submit/modal
|
|
||||||
+ [x] /mj/submit/shorten
|
|
||||||
+ [x] /mj/task/{id}/image-seed
|
|
||||||
+ [x] /mj/insight-face/swap (InsightFace)
|
|
||||||
|
|
||||||
## 模型列表
|
## 模型列表
|
||||||
|
|
||||||
### midjourney-proxy支持
|
### midjourney-proxy支持
|
||||||
@@ -72,11 +57,11 @@
|
|||||||
|
|
||||||
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是plus版本选择**Midjourney Proxy Plus**
|
||||||
,模型请参考上方模型列表
|
,模型请参考上方模型列表
|
||||||
3. **代理**填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
3. 地址填写midjourney-proxy部署的地址,例如:http://localhost:8080
|
||||||
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
|
||||||
|
|
||||||
### 对接上游new api
|
### 对接上游new api
|
||||||
|
|
||||||
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
|
1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
|
||||||
2. **代理**填写上游new api的地址,例如:http://localhost:3000
|
2. 地址填写上游new api的地址,例如:http://localhost:3000
|
||||||
3. 密钥填写上游new api的密钥
|
3. 密钥填写上游new api的密钥
|
||||||
50
README.md
50
README.md
@@ -5,6 +5,8 @@
|
|||||||
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
|
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
|
||||||
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||||
|
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
|
> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。
|
||||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||||
|
|
||||||
@@ -16,7 +18,19 @@
|
|||||||
此分叉版本的主要变更如下:
|
此分叉版本的主要变更如下:
|
||||||
|
|
||||||
1. 全新的UI界面(部分界面还待更新)
|
1. 全新的UI界面(部分界面还待更新)
|
||||||
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md)
|
2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md),支持的接口如下:
|
||||||
|
+ [x] /mj/submit/imagine
|
||||||
|
+ [x] /mj/submit/change
|
||||||
|
+ [x] /mj/submit/blend
|
||||||
|
+ [x] /mj/submit/describe
|
||||||
|
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
|
||||||
|
+ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
|
||||||
|
+ [x] /task/list-by-condition
|
||||||
|
+ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
|
||||||
|
+ [x] /mj/submit/modal
|
||||||
|
+ [x] /mj/submit/shorten
|
||||||
|
+ [x] /mj/task/{id}/image-seed
|
||||||
|
+ [x] /mj/insight-face/swap (InsightFace)
|
||||||
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
|
||||||
+ [x] 易支付
|
+ [x] 易支付
|
||||||
4. 支持用key查询使用额度:
|
4. 支持用key查询使用额度:
|
||||||
@@ -33,51 +47,29 @@
|
|||||||
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
|
||||||
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
3. 选择你的bot,然后输入http(s)://你的网站地址/login
|
||||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||||
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
|
|
||||||
14. 支持Rerank模型,目前仅兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md)
|
|
||||||
|
|
||||||
## 模型支持
|
## 模型支持
|
||||||
此版本额外支持以下模型:
|
此版本额外支持以下模型:
|
||||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||||
2. 智谱glm-4v,glm-4v识图
|
2. 智谱glm-4v,glm-4v识图
|
||||||
3. Anthropic Claude 3
|
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
|
||||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||||
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
|
||||||
6. [零一万物](https://platform.lingyiwanwu.com/)
|
6. [零一万物](https://platform.lingyiwanwu.com/)
|
||||||
7. 自定义渠道,支持填入完整调用地址
|
|
||||||
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
|
||||||
9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md)
|
|
||||||
10. Dify
|
|
||||||
|
|
||||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||||
|
|
||||||
## 渠道重试
|
## 渠道重试
|
||||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
|
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
|
||||||
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||||
### 缓存设置方法
|
### 缓存设置方法
|
||||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||||
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。
|
||||||
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
+ 例子:`MEMORY_CACHE_ENABLED=true`
|
||||||
### 为什么有的时候没有重试
|
|
||||||
这些错误码不会重试:400,504,524
|
|
||||||
### 我想让400也重试
|
|
||||||
在`渠道->编辑`中,将`状态码复写`改为
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"400": "500"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
可以实现400错误转为500错误,从而重试
|
|
||||||
|
|
||||||
## 比原版One API多出的配置
|
|
||||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒
|
|
||||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false`
|
|
||||||
- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型
|
|
||||||
## 部署
|
## 部署
|
||||||
### 部署要求
|
|
||||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
|
||||||
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
```shell
|
```shell
|
||||||
# 使用 SQLite 的部署命令:
|
# 使用 SQLite 的部署命令:
|
||||||
@@ -96,15 +88,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
|
|||||||
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
|
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
|
||||||
# 注意:数据库要开启远程访问,并且只允许服务器IP访问
|
# 注意:数据库要开启远程访问,并且只允许服务器IP访问
|
||||||
```
|
```
|
||||||
### 默认账号密码
|
|
||||||
默认账号root 密码123456
|
|
||||||
|
|
||||||
## Midjourney接口设置文档
|
## Midjourney接口设置文档
|
||||||
[对接文档](Midjourney.md)
|
[对接文档](Midjourney.md)
|
||||||
|
|
||||||
## Suno接口设置文档
|
|
||||||
[对接文档](Suno.md)
|
|
||||||
|
|
||||||
## 交流群
|
## 交流群
|
||||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
||||||
|
|
||||||
|
|||||||
62
Rerank.md
62
Rerank.md
@@ -1,62 +0,0 @@
|
|||||||
# Rerank API文档
|
|
||||||
|
|
||||||
**简介**:Rerank API文档
|
|
||||||
|
|
||||||
## 接入Dify
|
|
||||||
模型供应商选择Jina,按要求填写模型信息即可接入Dify。
|
|
||||||
|
|
||||||
## 请求方式
|
|
||||||
|
|
||||||
Post: /v1/rerank
|
|
||||||
|
|
||||||
Request:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"model": "rerank-multilingual-v3.0",
|
|
||||||
"query": "What is the capital of the United States?",
|
|
||||||
"top_n": 3,
|
|
||||||
"documents": [
|
|
||||||
"Carson City is the capital city of the American state of Nevada.",
|
|
||||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
|
|
||||||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
|
|
||||||
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
|
|
||||||
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Response:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"results": [
|
|
||||||
{
|
|
||||||
"document": {
|
|
||||||
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
|
|
||||||
},
|
|
||||||
"index": 2,
|
|
||||||
"relevance_score": 0.9999702
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"document": {
|
|
||||||
"text": "Carson City is the capital city of the American state of Nevada."
|
|
||||||
},
|
|
||||||
"index": 0,
|
|
||||||
"relevance_score": 0.67800725
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"document": {
|
|
||||||
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
|
|
||||||
},
|
|
||||||
"index": 3,
|
|
||||||
"relevance_score": 0.02800752
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 158,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 158
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
44
Suno.md
44
Suno.md
@@ -1,44 +0,0 @@
|
|||||||
# Suno API文档
|
|
||||||
|
|
||||||
**简介**:Suno API文档
|
|
||||||
|
|
||||||
## 接口列表
|
|
||||||
支持的接口如下:
|
|
||||||
+ [x] /suno/submit/music
|
|
||||||
+ [x] /suno/submit/lyrics
|
|
||||||
+ [x] /suno/fetch
|
|
||||||
+ [x] /suno/fetch/:id
|
|
||||||
|
|
||||||
## 模型列表
|
|
||||||
|
|
||||||
### Suno API支持
|
|
||||||
|
|
||||||
- suno_music (自定义模式、灵感模式、续写)
|
|
||||||
- suno_lyrics (生成歌词)
|
|
||||||
|
|
||||||
|
|
||||||
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"suno_music": 0.3,
|
|
||||||
"suno_lyrics": 0.01
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 渠道设置
|
|
||||||
|
|
||||||
### 对接 Suno API
|
|
||||||
|
|
||||||
1.
|
|
||||||
部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
|
|
||||||
|
|
||||||
2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
|
|
||||||
,模型请参考上方模型列表
|
|
||||||
3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
|
|
||||||
4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
|
|
||||||
|
|
||||||
### 对接上游new api
|
|
||||||
|
|
||||||
1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
|
|
||||||
2. **代理**填写上游new api的地址,例如:http://localhost:3000
|
|
||||||
3. 密钥填写上游new api的密钥
|
|
||||||
@@ -9,9 +9,19 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Pay Settings
|
||||||
|
|
||||||
|
var StripeApiSecret = ""
|
||||||
|
var StripeWebhookSecret = ""
|
||||||
|
var StripePriceId = ""
|
||||||
|
var PaymentEnabled = false
|
||||||
|
var StripeUnitPrice = 8.0
|
||||||
|
var MinTopUp = 5
|
||||||
|
|
||||||
var StartTime = time.Now().Unix() // unit: second
|
var StartTime = time.Now().Unix() // unit: second
|
||||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||||
var SystemName = "New API"
|
var SystemName = "New API"
|
||||||
|
var ServerAddress = "http://localhost:3000"
|
||||||
var Footer = ""
|
var Footer = ""
|
||||||
var Logo = ""
|
var Logo = ""
|
||||||
var TopUpLink = ""
|
var TopUpLink = ""
|
||||||
@@ -21,7 +31,6 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
|||||||
var DisplayInCurrencyEnabled = true
|
var DisplayInCurrencyEnabled = true
|
||||||
var DisplayTokenStatEnabled = true
|
var DisplayTokenStatEnabled = true
|
||||||
var DrawingEnabled = true
|
var DrawingEnabled = true
|
||||||
var TaskEnabled = true
|
|
||||||
var DataExportEnabled = true
|
var DataExportEnabled = true
|
||||||
var DataExportInterval = 5 // unit: minute
|
var DataExportInterval = 5 // unit: minute
|
||||||
var DataExportDefaultTime = "hour" // unit: minute
|
var DataExportDefaultTime = "hour" // unit: minute
|
||||||
@@ -41,10 +50,12 @@ var PasswordLoginEnabled = true
|
|||||||
var PasswordRegisterEnabled = true
|
var PasswordRegisterEnabled = true
|
||||||
var EmailVerificationEnabled = false
|
var EmailVerificationEnabled = false
|
||||||
var GitHubOAuthEnabled = false
|
var GitHubOAuthEnabled = false
|
||||||
|
var LinuxDoOAuthEnabled = false
|
||||||
var WeChatAuthEnabled = false
|
var WeChatAuthEnabled = false
|
||||||
var TelegramOAuthEnabled = false
|
var TelegramOAuthEnabled = false
|
||||||
var TurnstileCheckEnabled = false
|
var TurnstileCheckEnabled = false
|
||||||
var RegisterEnabled = true
|
var RegisterEnabled = true
|
||||||
|
var UserSelfDeletionEnabled = false
|
||||||
|
|
||||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||||
@@ -75,6 +86,10 @@ var SMTPToken = ""
|
|||||||
var GitHubClientId = ""
|
var GitHubClientId = ""
|
||||||
var GitHubClientSecret = ""
|
var GitHubClientSecret = ""
|
||||||
|
|
||||||
|
var LinuxDoClientId = ""
|
||||||
|
var LinuxDoClientSecret = ""
|
||||||
|
var LinuxDoMinLevel = 0
|
||||||
|
|
||||||
var WeChatServerAddress = ""
|
var WeChatServerAddress = ""
|
||||||
var WeChatServerToken = ""
|
var WeChatServerToken = ""
|
||||||
var WeChatAccountQRCodeImageURL = ""
|
var WeChatAccountQRCodeImageURL = ""
|
||||||
@@ -103,14 +118,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||||
|
|
||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||||
|
|
||||||
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RequestIdKey = "X-Oneapi-Request-Id"
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
@@ -133,10 +148,10 @@ var (
|
|||||||
// All duration's unit is seconds
|
// All duration's unit is seconds
|
||||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||||
var (
|
var (
|
||||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||||
GlobalApiRateLimitDuration int64 = 3 * 60
|
GlobalApiRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration int64 = 3 * 60
|
GlobalWebRateLimitDuration int64 = 3 * 60
|
||||||
|
|
||||||
UploadRateLimitNum = 10
|
UploadRateLimitNum = 10
|
||||||
@@ -176,6 +191,12 @@ const (
|
|||||||
ChannelStatusAutoDisabled = 3
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TopUpStatusPending = "pending"
|
||||||
|
TopUpStatusSuccess = "success"
|
||||||
|
TopUpStatusExpired = "expired"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelTypeUnknown = 0
|
ChannelTypeUnknown = 0
|
||||||
ChannelTypeOpenAI = 1
|
ChannelTypeOpenAI = 1
|
||||||
@@ -208,13 +229,6 @@ const (
|
|||||||
ChannelTypeLingYiWanWu = 31
|
ChannelTypeLingYiWanWu = 31
|
||||||
ChannelTypeAws = 33
|
ChannelTypeAws = 33
|
||||||
ChannelTypeCohere = 34
|
ChannelTypeCohere = 34
|
||||||
ChannelTypeMiniMax = 35
|
|
||||||
ChannelTypeSunoAPI = 36
|
|
||||||
ChannelTypeDify = 37
|
|
||||||
ChannelTypeJina = 38
|
|
||||||
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -253,8 +267,4 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //32
|
"", //32
|
||||||
"", //33
|
"", //33
|
||||||
"https://api.cohere.ai", //34
|
"https://api.cohere.ai", //34
|
||||||
"https://api.minimax.chat", //35
|
|
||||||
"", //36
|
|
||||||
"", //37
|
|
||||||
"https://api.jina.ai", //38
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetEnvOrDefault(env string, defaultValue int) int {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
num, err := strconv.Atoi(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return num
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetEnvOrDefaultString(env string, defaultValue string) string {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return os.Getenv(env)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
|
|
||||||
if env == "" || os.Getenv(env) == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
b, err := strconv.ParseBool(os.Getenv(env))
|
|
||||||
if err != nil {
|
|
||||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func SafeGoroutine(f func()) {
|
func SafeGoroutine(f func()) {
|
||||||
@@ -46,21 +45,3 @@ func SafeSendString(ch chan string, value string) (closed bool) {
|
|||||||
// If the code reaches here, then the channel was not closed.
|
// If the code reaches here, then the channel was not closed.
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// SafeSendStringTimeout send, return true, else return false
|
|
||||||
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
|
|
||||||
defer func() {
|
|
||||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
|
||||||
if recover() != nil {
|
|
||||||
closed = false
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// This will panic if the channel is closed.
|
|
||||||
select {
|
|
||||||
case ch <- value:
|
|
||||||
return true
|
|
||||||
case <-time.After(time.Duration(timeout) * time.Second):
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import "encoding/json"
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
var GroupRatio = map[string]float64{
|
var GroupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
|
|||||||
84
common/hash.go
Normal file
84
common/hash.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Sha256Raw(data string) []byte {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write([]byte(data))
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1Raw(data []byte) []byte {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(data))
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1(data string) string {
|
||||||
|
return hex.EncodeToString(Sha1Raw([]byte(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256Raw(message, key []byte) []byte {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write(message)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256(message, key string) string {
|
||||||
|
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomBytes(length int) []byte {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
b := make([]byte, length)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomString(length int) string {
|
||||||
|
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
result := make([]byte, length)
|
||||||
|
randomBytes := RandomBytes(length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomHex(length int) string {
|
||||||
|
const chars = "abcdef0123456789"
|
||||||
|
result := make([]byte, length)
|
||||||
|
randomBytes := RandomBytes(length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomNumber(length int) string {
|
||||||
|
const chars = "0123456789"
|
||||||
|
result := make([]byte, length)
|
||||||
|
randomBytes := RandomBytes(length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RandomUUID() string {
|
||||||
|
all := RandomHex(32)
|
||||||
|
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"golang.org/x/image/webp"
|
"golang.org/x/image/webp"
|
||||||
"image"
|
"image"
|
||||||
"io"
|
"io"
|
||||||
"one-api/common"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,13 +31,25 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
|||||||
return config, format, base64String, err
|
return config, format, base64String, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
func IsImageUrl(url string) (bool, error) {
|
||||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
resp, err := http.Head(url)
|
||||||
resp, err := DoImageRequest(url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return false, err
|
||||||
}
|
}
|
||||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||||
|
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||||
|
isImage, err := IsImageUrl(url)
|
||||||
|
if !isImage {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -52,21 +64,16 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||||
response, err := DoImageRequest(imageUrl)
|
response, err := http.Get(imageUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||||
return image.Config{}, "", err
|
return image.Config{}, "", err
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
|
||||||
if response.StatusCode != 200 {
|
|
||||||
err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
|
|
||||||
return image.Config{}, "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var readData []byte
|
var readData []byte
|
||||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||||
|
|
||||||
// 从response.Body读取更多的数据直到达到当前的限制
|
// 从response.Body读取更多的数据直到达到当前的限制
|
||||||
additionalData := make([]byte, limit-int64(len(readData)))
|
additionalData := make([]byte, limit-int64(len(readData)))
|
||||||
@@ -92,11 +99,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
|||||||
config, format, err := image.DecodeConfig(reader)
|
config, format, err := image.DecodeConfig(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||||
common.SysLog(err.Error())
|
SysLog(err.Error())
|
||||||
config, err = webp.DecodeConfig(reader)
|
config, err = webp.DecodeConfig(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||||
common.SysLog(err.Error())
|
SysLog(err.Error())
|
||||||
}
|
}
|
||||||
format = "webp"
|
format = "webp"
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,6 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogJson 仅供测试使用 only for test
|
func LogQuotaF(quota float64) string {
|
||||||
func LogJson(ctx context.Context, msg string, obj any) {
|
if DisplayInCurrencyEnabled {
|
||||||
jsonStr, err := json.Marshal(obj)
|
return fmt.Sprintf("$%.6f 额度", quota/QuotaPerUnit)
|
||||||
if err != nil {
|
} else {
|
||||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
return fmt.Sprintf("%d 点额度", int64(quota))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// from songquanpeng/one-api
|
|
||||||
const (
|
|
||||||
USD2RMB = 7.3 // 暂定 1 USD = 7.3 RMB
|
|
||||||
USD = 500 // $0.002 = 1 -> $1 = 500
|
|
||||||
RMB = USD / USD2RMB
|
|
||||||
)
|
|
||||||
|
|
||||||
// modelRatio
|
// modelRatio
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||||
@@ -20,128 +13,99 @@ const (
|
|||||||
// 1 === $0.002 / 1K tokens
|
// 1 === $0.002 / 1K tokens
|
||||||
// 1 === ¥0.014 / 1k tokens
|
// 1 === ¥0.014 / 1k tokens
|
||||||
|
|
||||||
var defaultModelRatio = map[string]float64{
|
var DefaultModelRatio = map[string]float64{
|
||||||
//"midjourney": 50,
|
//"midjourney": 50,
|
||||||
"gpt-4-gizmo-*": 15,
|
"gpt-4-gizmo-*": 15,
|
||||||
"gpt-4-all": 15,
|
"gpt-4": 15,
|
||||||
"gpt-4o-all": 15,
|
"gpt-4-0314": 15,
|
||||||
"gpt-4": 15,
|
"gpt-4-0613": 15,
|
||||||
//"gpt-4-0314": 15, //deprecated
|
"gpt-4-32k": 30,
|
||||||
"gpt-4-0613": 15,
|
"gpt-4-32k-0314": 30,
|
||||||
"gpt-4-32k": 30,
|
|
||||||
//"gpt-4-32k-0314": 30, //deprecated
|
|
||||||
"gpt-4-32k-0613": 30,
|
"gpt-4-32k-0613": 30,
|
||||||
|
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||||
|
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-4o": 2.5, // $0.01 / 1K tokens
|
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
|
||||||
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
|
"gpt-3.5-turbo-0301": 0.75,
|
||||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
"gpt-3.5-turbo-0613": 0.75,
|
||||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||||
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
|
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||||
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||||
"gpt-3.5-turbo-0613": 0.75,
|
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
"gpt-3.5-turbo-0125": 0.25,
|
||||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
"davinci-002": 1, // $0.002 / 1K tokens
|
||||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
"text-ada-001": 0.2,
|
||||||
"gpt-3.5-turbo-0125": 0.25,
|
"text-babbage-001": 0.25,
|
||||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
"text-curie-001": 1,
|
||||||
"davinci-002": 1, // $0.002 / 1K tokens
|
"text-davinci-002": 10,
|
||||||
"text-ada-001": 0.2,
|
"text-davinci-003": 10,
|
||||||
"text-babbage-001": 0.25,
|
"text-davinci-edit-001": 10,
|
||||||
"text-curie-001": 1,
|
"code-davinci-edit-001": 10,
|
||||||
//"text-davinci-002": 10,
|
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||||
//"text-davinci-003": 10,
|
"tts-1": 7.5, // 1k characters -> $0.015
|
||||||
"text-davinci-edit-001": 10,
|
"tts-1-1106": 7.5, // 1k characters -> $0.015
|
||||||
"code-davinci-edit-001": 10,
|
"tts-1-hd": 15, // 1k characters -> $0.03
|
||||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
"tts-1-hd-1106": 15, // 1k characters -> $0.03
|
||||||
"tts-1": 7.5, // 1k characters -> $0.015
|
"davinci": 10,
|
||||||
"tts-1-1106": 7.5, // 1k characters -> $0.015
|
"curie": 10,
|
||||||
"tts-1-hd": 15, // 1k characters -> $0.03
|
"babbage": 10,
|
||||||
"tts-1-hd-1106": 15, // 1k characters -> $0.03
|
"ada": 10,
|
||||||
"davinci": 10,
|
"text-embedding-3-small": 0.01,
|
||||||
"curie": 10,
|
"text-embedding-3-large": 0.065,
|
||||||
"babbage": 10,
|
"text-embedding-ada-002": 0.05,
|
||||||
"ada": 10,
|
"text-search-ada-doc-001": 10,
|
||||||
"text-embedding-3-small": 0.01,
|
"text-moderation-stable": 0.1,
|
||||||
"text-embedding-3-large": 0.065,
|
"text-moderation-latest": 0.1,
|
||||||
"text-embedding-ada-002": 0.05,
|
"dall-e-2": 8,
|
||||||
"text-search-ada-doc-001": 10,
|
"dall-e-3": 16,
|
||||||
"text-moderation-stable": 0.1,
|
"claude-instant-1": 0.4, // $0.8 / 1M tokens
|
||||||
"text-moderation-latest": 0.1,
|
"claude-2.0": 4, // $8 / 1M tokens
|
||||||
"claude-instant-1": 0.4, // $0.8 / 1M tokens
|
"claude-2.1": 4, // $8 / 1M tokens
|
||||||
"claude-2.0": 4, // $8 / 1M tokens
|
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||||
"claude-2.1": 4, // $8 / 1M tokens
|
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"claude-3-5-sonnet-20240620": 1.5,
|
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"ERNIE-3.5-8K": 0.012 * RMB,
|
"PaLM-2": 1,
|
||||||
"ERNIE-3.5-8K-0205": 0.024 * RMB,
|
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"ERNIE-3.5-8K-1222": 0.012 * RMB,
|
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"ERNIE-Bot-8K": 0.024 * RMB,
|
"gemini-1.0-pro-vision-001": 1,
|
||||||
"ERNIE-3.5-4K-0205": 0.012 * RMB,
|
"gemini-1.0-pro-001": 1,
|
||||||
"ERNIE-Speed-8K": 0.004 * RMB,
|
"gemini-1.5-pro-latest": 1,
|
||||||
"ERNIE-Speed-128K": 0.004 * RMB,
|
"gemini-1.0-pro-latest": 1,
|
||||||
"ERNIE-Lite-8K-0922": 0.008 * RMB,
|
"gemini-1.0-pro-vision-latest": 1,
|
||||||
"ERNIE-Lite-8K-0308": 0.003 * RMB,
|
"gemini-ultra": 1,
|
||||||
"ERNIE-Tiny-8K": 0.001 * RMB,
|
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"BLOOMZ-7B": 0.004 * RMB,
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"Embedding-V1": 0.002 * RMB,
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"bge-large-zh": 0.002 * RMB,
|
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"bge-large-en": 0.002 * RMB,
|
"glm-4": 7.143, // ¥0.1 / 1k tokens
|
||||||
"tao-8k": 0.002 * RMB,
|
"glm-4v": 7.143, // ¥0.1 / 1k tokens
|
||||||
"PaLM-2": 1,
|
"glm-3-turbo": 0.3572,
|
||||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||||
"gemini-1.0-pro-vision-001": 1,
|
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||||
"gemini-1.0-pro-001": 1,
|
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"gemini-1.5-pro-latest": 1,
|
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"gemini-1.5-flash-latest": 1,
|
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"gemini-1.0-pro-latest": 1,
|
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
||||||
"gemini-1.0-pro-vision-latest": 1,
|
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"gemini-ultra": 1,
|
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
|
||||||
"glm-4": 7.143, // ¥0.1 / 1k tokens
|
|
||||||
"glm-4v": 7.143, // ¥0.1 / 1k tokens
|
|
||||||
"glm-3-turbo": 0.3572,
|
|
||||||
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
|
||||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
|
||||||
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
|
|
||||||
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
|
|
||||||
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
|
|
||||||
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
|
|
||||||
"SparkDesk-v4.0": 1.2858,
|
|
||||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
|
|
||||||
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"360gpt-pro": 0.8572, // ¥0.012 / 1k tokens
|
|
||||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
|
||||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
|
||||||
// https://platform.lingyiwanwu.com/docs#-计费单元
|
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||||
// 已经按照 7.2 来换算美元价格
|
// 已经按照 7.2 来换算美元价格
|
||||||
"yi-34b-chat-0205": 0.18,
|
"yi-34b-chat-0205": 0.018,
|
||||||
"yi-34b-chat-200k": 0.864,
|
"yi-34b-chat-200k": 0.0864,
|
||||||
"yi-vl-plus": 0.432,
|
"yi-vl-plus": 0.0432,
|
||||||
"yi-large": 20.0 / 1000 * RMB,
|
|
||||||
"yi-medium": 2.5 / 1000 * RMB,
|
|
||||||
"yi-vision": 6.0 / 1000 * RMB,
|
|
||||||
"yi-medium-200k": 12.0 / 1000 * RMB,
|
|
||||||
"yi-spark": 1.0 / 1000 * RMB,
|
|
||||||
"yi-large-rag": 25.0 / 1000 * RMB,
|
|
||||||
"yi-large-turbo": 12.0 / 1000 * RMB,
|
|
||||||
"yi-large-preview": 20.0 / 1000 * RMB,
|
|
||||||
"yi-large-rag-preview": 25.0 / 1000 * RMB,
|
|
||||||
"command": 0.5,
|
"command": 0.5,
|
||||||
"command-nightly": 0.5,
|
"command-nightly": 0.5,
|
||||||
"command-light": 0.5,
|
"command-light": 0.5,
|
||||||
@@ -150,15 +114,9 @@ var defaultModelRatio = map[string]float64{
|
|||||||
"command-r-plus ": 1.5,
|
"command-r-plus ": 1.5,
|
||||||
"deepseek-chat": 0.07,
|
"deepseek-chat": 0.07,
|
||||||
"deepseek-coder": 0.07,
|
"deepseek-coder": 0.07,
|
||||||
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
|
|
||||||
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
|
|
||||||
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
|
|
||||||
"llama-3-sonar-large-32k-chat": 1 / 1000 * USD,
|
|
||||||
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultModelPrice = map[string]float64{
|
var DefaultModelPrice = map[string]float64{
|
||||||
"dall-e-3": 0.04,
|
|
||||||
"gpt-4-gizmo-*": 0.1,
|
"gpt-4-gizmo-*": 0.1,
|
||||||
"mj_imagine": 0.1,
|
"mj_imagine": 0.1,
|
||||||
"mj_variation": 0.1,
|
"mj_variation": 0.1,
|
||||||
@@ -180,15 +138,9 @@ var defaultModelPrice = map[string]float64{
|
|||||||
var modelPrice map[string]float64 = nil
|
var modelPrice map[string]float64 = nil
|
||||||
var modelRatio map[string]float64 = nil
|
var modelRatio map[string]float64 = nil
|
||||||
|
|
||||||
var CompletionRatio map[string]float64 = nil
|
|
||||||
var defaultCompletionRatio = map[string]float64{
|
|
||||||
"gpt-4-gizmo-*": 2,
|
|
||||||
"gpt-4-all": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelPrice2JSONString() string {
|
func ModelPrice2JSONString() string {
|
||||||
if modelPrice == nil {
|
if modelPrice == nil {
|
||||||
modelPrice = defaultModelPrice
|
modelPrice = DefaultModelPrice
|
||||||
}
|
}
|
||||||
jsonBytes, err := json.Marshal(modelPrice)
|
jsonBytes, err := json.Marshal(modelPrice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -202,10 +154,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
|||||||
return json.Unmarshal([]byte(jsonStr), &modelPrice)
|
return json.Unmarshal([]byte(jsonStr), &modelPrice)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
func GetModelPrice(name string, printErr bool) float64 {
|
||||||
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|
||||||
if modelPrice == nil {
|
if modelPrice == nil {
|
||||||
modelPrice = defaultModelPrice
|
modelPrice = DefaultModelPrice
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
@@ -215,21 +166,14 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|||||||
if printErr {
|
if printErr {
|
||||||
SysError("model price not found: " + name)
|
SysError("model price not found: " + name)
|
||||||
}
|
}
|
||||||
return -1, false
|
return -1
|
||||||
}
|
}
|
||||||
return price, true
|
return price
|
||||||
}
|
|
||||||
|
|
||||||
func GetModelPriceMap() map[string]float64 {
|
|
||||||
if modelPrice == nil {
|
|
||||||
modelPrice = defaultModelPrice
|
|
||||||
}
|
|
||||||
return modelPrice
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
if modelRatio == nil {
|
if modelRatio == nil {
|
||||||
modelRatio = defaultModelRatio
|
modelRatio = DefaultModelRatio
|
||||||
}
|
}
|
||||||
jsonBytes, err := json.Marshal(modelRatio)
|
jsonBytes, err := json.Marshal(modelRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -245,7 +189,7 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
|||||||
|
|
||||||
func GetModelRatio(name string) float64 {
|
func GetModelRatio(name string) float64 {
|
||||||
if modelRatio == nil {
|
if modelRatio == nil {
|
||||||
modelRatio = defaultModelRatio
|
modelRatio = DefaultModelRatio
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
@@ -258,60 +202,31 @@ func GetModelRatio(name string) float64 {
|
|||||||
return ratio
|
return ratio
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultModelRatio2JSONString() string {
|
|
||||||
jsonBytes, err := json.Marshal(defaultModelRatio)
|
|
||||||
if err != nil {
|
|
||||||
SysError("error marshalling model ratio: " + err.Error())
|
|
||||||
}
|
|
||||||
return string(jsonBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetDefaultModelRatioMap() map[string]float64 {
|
|
||||||
return defaultModelRatio
|
|
||||||
}
|
|
||||||
|
|
||||||
func CompletionRatio2JSONString() string {
|
|
||||||
if CompletionRatio == nil {
|
|
||||||
CompletionRatio = defaultCompletionRatio
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(CompletionRatio)
|
|
||||||
if err != nil {
|
|
||||||
SysError("error marshalling completion ratio: " + err.Error())
|
|
||||||
}
|
|
||||||
return string(jsonBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
|
||||||
CompletionRatio = make(map[string]float64)
|
|
||||||
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
|
||||||
name = "gpt-4-gizmo-*"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-3.5") {
|
if strings.HasPrefix(name, "gpt-3.5") {
|
||||||
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
|
if strings.HasSuffix(name, "0125") {
|
||||||
// https://openai.com/blog/new-embedding-models-and-api-updates
|
|
||||||
// Updated GPT-3.5 Turbo model and lower pricing
|
|
||||||
return 3
|
return 3
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(name, "1106") {
|
if strings.HasSuffix(name, "1106") {
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
return 4.0 / 3.0
|
if name == "gpt-3.5-turbo" {
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1.333333
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
|
if strings.HasPrefix(name, "gpt-4") {
|
||||||
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4o") {
|
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") {
|
||||||
return 3
|
return 3
|
||||||
}
|
}
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
if strings.Contains(name, "claude-instant-1") {
|
if strings.HasPrefix(name, "claude-instant-1") {
|
||||||
return 3
|
return 3
|
||||||
} else if strings.Contains(name, "claude-2") {
|
} else if strings.HasPrefix(name, "claude-2") {
|
||||||
return 3
|
return 3
|
||||||
} else if strings.Contains(name, "claude-3") {
|
} else if strings.HasPrefix(name, "claude-3") {
|
||||||
return 5
|
return 5
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "mistral-") {
|
if strings.HasPrefix(name, "mistral-") {
|
||||||
@@ -333,32 +248,9 @@ func GetCompletionRatio(name string) float64 {
|
|||||||
if strings.HasPrefix(name, "deepseek") {
|
if strings.HasPrefix(name, "deepseek") {
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "ERNIE-Speed-") {
|
|
||||||
return 2
|
|
||||||
} else if strings.HasPrefix(name, "ERNIE-Lite-") {
|
|
||||||
return 2
|
|
||||||
} else if strings.HasPrefix(name, "ERNIE-Character") {
|
|
||||||
return 2
|
|
||||||
} else if strings.HasPrefix(name, "ERNIE-Functions") {
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
switch name {
|
switch name {
|
||||||
case "llama2-70b-4096":
|
case "llama2-70b-4096":
|
||||||
return 0.8 / 0.64
|
return 0.8 / 0.7
|
||||||
case "llama3-8b-8192":
|
|
||||||
return 2
|
|
||||||
case "llama3-70b-8192":
|
|
||||||
return 0.79 / 0.59
|
|
||||||
}
|
|
||||||
if ratio, ok := CompletionRatio[name]; ok {
|
|
||||||
return ratio
|
|
||||||
}
|
}
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCompletionRatioMap() map[string]float64 {
|
|
||||||
if CompletionRatio == nil {
|
|
||||||
CompletionRatio = defaultCompletionRatio
|
|
||||||
}
|
|
||||||
return CompletionRatio
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,12 +1,4 @@
|
|||||||
package service
|
package common
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
goahocorasick "github.com/anknown/ahocorasick"
|
|
||||||
"one-api/constant"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func SundaySearch(text string, pattern string) bool {
|
func SundaySearch(text string, pattern string) bool {
|
||||||
// 计算偏移表
|
// 计算偏移表
|
||||||
@@ -56,25 +48,3 @@ func RemoveDuplicate(s []string) []string {
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitAc() *goahocorasick.Machine {
|
|
||||||
m := new(goahocorasick.Machine)
|
|
||||||
dict := readRunes()
|
|
||||||
if err := m.Build(dict); err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func readRunes() [][]rune {
|
|
||||||
var dict [][]rune
|
|
||||||
|
|
||||||
for _, word := range constant.SensitiveWords {
|
|
||||||
word = strings.ToLower(word)
|
|
||||||
l := bytes.TrimSpace([]byte(word))
|
|
||||||
dict = append(dict, bytes.Runes(l))
|
|
||||||
}
|
|
||||||
|
|
||||||
return dict
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import "encoding/json"
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
var TopupGroupRatio = map[string]float64{
|
var TopupGroupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"html/template"
|
"html/template"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -190,6 +190,25 @@ func Max(a int, b int) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetOrDefault(env string, defaultValue int) int {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
num, err := strconv.Atoi(os.Getenv(env))
|
||||||
|
if err != nil {
|
||||||
|
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOrDefaultString(env string, defaultValue string) string {
|
||||||
|
if env == "" || os.Getenv(env) == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return os.Getenv(env)
|
||||||
|
}
|
||||||
|
|
||||||
func MessageWithRequestId(message string, id string) string {
|
func MessageWithRequestId(message string, id string) string {
|
||||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
}
|
}
|
||||||
@@ -222,28 +241,3 @@ func RandomSleep() {
|
|||||||
// Sleep for 0-3000 ms
|
// Sleep for 0-3000 ms
|
||||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MapToJsonStr(m map[string]interface{}) string {
|
|
||||||
bytes, err := json.Marshal(m)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return string(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MapToJsonStrFloat(m map[string]float64) string {
|
|
||||||
bytes, err := json.Marshal(m)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return string(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func StrToMap(str string) map[string]interface{} {
|
|
||||||
m := make(map[string]interface{})
|
|
||||||
err := json.Unmarshal([]byte(str), &m)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
|
||||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
|
||||||
|
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
|
||||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
|
||||||
@@ -4,7 +4,6 @@ var MjNotifyEnabled = false
|
|||||||
var MjAccountFilterEnabled = false
|
var MjAccountFilterEnabled = false
|
||||||
var MjModeClearEnabled = false
|
var MjModeClearEnabled = false
|
||||||
var MjForwardUrlEnabled = true
|
var MjForwardUrlEnabled = true
|
||||||
var MjActionCheckSuccessEnabled = true
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MjErrorUnknown = 5
|
MjErrorUnknown = 5
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0
|
|||||||
// SensitiveWords 敏感词
|
// SensitiveWords 敏感词
|
||||||
// var SensitiveWords []string
|
// var SensitiveWords []string
|
||||||
var SensitiveWords = []string{
|
var SensitiveWords = []string{
|
||||||
"test_sensitive",
|
"test",
|
||||||
}
|
}
|
||||||
|
|
||||||
func SensitiveWordsToString() string {
|
func SensitiveWordsToString() string {
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
var ServerAddress = "http://localhost:3000"
|
|
||||||
var WorkerUrl = ""
|
|
||||||
var WorkerValidKey = ""
|
|
||||||
|
|
||||||
func EnableWorker() bool {
|
|
||||||
return WorkerUrl != ""
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
type TaskPlatform string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
|
||||||
TaskPlatformMidjourney = "mj"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
SunoActionMusic = "MUSIC"
|
|
||||||
SunoActionLyrics = "LYRICS"
|
|
||||||
)
|
|
||||||
|
|
||||||
var SunoModel2Action = map[string]string{
|
|
||||||
"suno_music": SunoActionMusic,
|
|
||||||
"suno_lyrics": SunoActionLyrics,
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -25,13 +24,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||||
tik := time.Now()
|
|
||||||
if channel.Type == common.ChannelTypeMidjourney {
|
if channel.Type == common.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return errors.New("midjourney channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeSunoAPI {
|
|
||||||
return errors.New("suno channel test is not supported"), nil
|
|
||||||
}
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
c.Request = &http.Request{
|
c.Request = &http.Request{
|
||||||
@@ -58,7 +53,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
}
|
}
|
||||||
|
|
||||||
meta := relaycommon.GenRelayInfo(c)
|
meta := relaycommon.GenRelayInfo(c)
|
||||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
apiType := constant.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||||
@@ -67,27 +62,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||||
testModel = *channel.TestModel
|
testModel = *channel.TestModel
|
||||||
} else {
|
} else {
|
||||||
if len(channel.GetModels()) > 0 {
|
testModel = adaptor.GetModelList()[0]
|
||||||
testModel = channel.GetModels()[0]
|
|
||||||
} else {
|
|
||||||
testModel = "gpt-3.5-turbo"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
modelMapping := *channel.ModelMapping
|
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
|
|
||||||
return err, &openaiErr
|
|
||||||
}
|
|
||||||
if modelMap[testModel] != "" {
|
|
||||||
testModel = modelMap[testModel]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest()
|
request := buildTestRequest()
|
||||||
request.Model = testModel
|
request.Model = testModel
|
||||||
meta.UpstreamModelName = testModel
|
meta.UpstreamModelName = testModel
|
||||||
@@ -126,25 +103,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
modelPrice, usePrice := common.GetModelPrice(testModel, false)
|
|
||||||
modelRatio := common.GetModelRatio(testModel)
|
|
||||||
completionRatio := common.GetCompletionRatio(testModel)
|
|
||||||
ratio := modelRatio
|
|
||||||
quota := 0
|
|
||||||
if !usePrice {
|
|
||||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
|
|
||||||
quota = int(math.Round(float64(quota) * ratio))
|
|
||||||
if ratio != 0 && quota <= 0 {
|
|
||||||
quota = 1
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
quota = int(modelPrice * common.QuotaPerUnit)
|
|
||||||
}
|
|
||||||
tok := time.Now()
|
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
|
||||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
|
||||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
|
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -165,7 +123,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChannel(c *gin.Context) {
|
func TestChannel(c *gin.Context) {
|
||||||
channelId, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -173,7 +131,7 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err := model.GetChannelById(channelId, true)
|
channel, err := model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -247,18 +205,11 @@ func testAllChannels(notify bool) error {
|
|||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||||
ban = false
|
ban = false
|
||||||
}
|
}
|
||||||
if openaiErr != nil {
|
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
|
||||||
openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{
|
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
StatusCode: -1,
|
}
|
||||||
Error: *openaiErr,
|
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
|
||||||
LocalError: false,
|
service.EnableChannel(channel.Id, channel.Name)
|
||||||
}
|
|
||||||
if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban {
|
|
||||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
|
||||||
}
|
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
|
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -11,34 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIModel struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
OwnedBy string `json:"owned_by"`
|
|
||||||
Permission []struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
|
||||||
AllowSampling bool `json:"allow_sampling"`
|
|
||||||
AllowLogprobs bool `json:"allow_logprobs"`
|
|
||||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
|
||||||
AllowView bool `json:"allow_view"`
|
|
||||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
||||||
Organization string `json:"organization"`
|
|
||||||
Group string `json:"group"`
|
|
||||||
IsBlocking bool `json:"is_blocking"`
|
|
||||||
} `json:"permission"`
|
|
||||||
Root string `json:"root"`
|
|
||||||
Parent string `json:"parent"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIModelsResponse struct {
|
|
||||||
Data []OpenAIModel `json:"data"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAllChannels(c *gin.Context) {
|
func GetAllChannels(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
p, _ := strconv.Atoi(c.Query("p"))
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||||
@@ -65,65 +35,6 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func FetchUpstreamModels(c *gin.Context) {
|
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
channel, err := model.GetChannelById(id, true)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if channel.Type != common.ChannelTypeOpenAI {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "仅支持 OpenAI 类型渠道",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
result := OpenAIModelsResponse{}
|
|
||||||
err = json.Unmarshal(body, &result)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if !result.Success {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": "上游返回错误",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var ids []string
|
|
||||||
for _, model := range result.Data {
|
|
||||||
ids = append(ids, model.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"data": ids,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func FixChannelsAbilities(c *gin.Context) {
|
func FixChannelsAbilities(c *gin.Context) {
|
||||||
count, err := model.FixAbility()
|
count, err := model.FixAbility()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if common.RegisterEnabled {
|
||||||
|
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
|
||||||
|
|
||||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
if githubUser.Name != "" {
|
if githubUser.Name != "" {
|
||||||
user.DisplayName = githubUser.Name
|
user.DisplayName = githubUser.Name
|
||||||
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
user.Role = common.RoleCommonUser
|
user.Role = common.RoleCommonUser
|
||||||
user.Status = common.UserStatusEnabled
|
user.Status = common.UserStatusEnabled
|
||||||
|
|
||||||
if err := user.Insert(0); err != nil {
|
if err := user.Insert(user.InviterId); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
|||||||
239
controller/linuxdo.go
Normal file
239
controller/linuxdo.go
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LinuxDoOAuthResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LinuxDoUser struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Active bool `json:"active"`
|
||||||
|
TrustLevel int `json:"trust_level"`
|
||||||
|
Silenced bool `json:"silenced"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
|
||||||
|
if code == "" {
|
||||||
|
return nil, errors.New("无效的参数")
|
||||||
|
}
|
||||||
|
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Authorization", "Basic "+auth)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
var oAuthResponse LinuxDoOAuthResponse
|
||||||
|
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||||
|
res2, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog(err.Error())
|
||||||
|
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||||
|
}
|
||||||
|
defer res2.Body.Close()
|
||||||
|
var linuxdoUser LinuxDoUser
|
||||||
|
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if linuxdoUser.ID == 0 {
|
||||||
|
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||||
|
}
|
||||||
|
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
|
||||||
|
return nil, errors.New("用户 LINUX DO 信任等级不足!")
|
||||||
|
}
|
||||||
|
return &linuxdoUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxDoOAuth(c *gin.Context) {
|
||||||
|
session := sessions.Default(c)
|
||||||
|
state := c.Query("state")
|
||||||
|
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "state is empty or not same",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
username := session.Get("username")
|
||||||
|
if username != nil {
|
||||||
|
LinuxDoBind(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !common.LinuxDoOAuthEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := c.Query("code")
|
||||||
|
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||||
|
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||||
|
}
|
||||||
|
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||||
|
err := user.FillUserByLinuxDoId()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||||
|
err = user.Update(false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if common.RegisterEnabled {
|
||||||
|
affCode := c.Query("aff")
|
||||||
|
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
|
||||||
|
|
||||||
|
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
|
if linuxdoUser.Name != "" {
|
||||||
|
user.DisplayName = linuxdoUser.Name
|
||||||
|
} else {
|
||||||
|
user.DisplayName = linuxdoUser.Username
|
||||||
|
}
|
||||||
|
user.Role = common.RoleCommonUser
|
||||||
|
user.Status = common.UserStatusEnabled
|
||||||
|
|
||||||
|
if err := user.Insert(user.InviterId); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员关闭了新用户注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Status != common.UserStatusEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "用户已被封禁",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setupLogin(&user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxDoBind(c *gin.Context) {
|
||||||
|
if !common.LinuxDoOAuthEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := c.Query("code")
|
||||||
|
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user := model.User{
|
||||||
|
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||||
|
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||||
|
}
|
||||||
|
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "该 LINUX DO 账户已被绑定",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session := sessions.Default(c)
|
||||||
|
id := session.Get("id")
|
||||||
|
// id := c.GetInt("id") // critical bug!
|
||||||
|
user.Id = id.(int)
|
||||||
|
err = user.FillUserById()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
|
||||||
|
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||||
|
err = user.Update(false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "bind",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -24,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -35,6 +35,7 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
"total": total,
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -58,7 +59,7 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -69,6 +70,7 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
"total": total,
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ func GetAllMidjourney(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if constant.MjForwardUrlEnabled {
|
if constant.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range logs {
|
for i, midjourney := range logs {
|
||||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
logs[i] = midjourney
|
logs[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,7 +267,7 @@ func GetUserMidjourney(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if constant.MjForwardUrlEnabled {
|
if constant.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range logs {
|
for i, midjourney := range logs {
|
||||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
logs[i] = midjourney
|
logs[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
|
|||||||
"email_verification": common.EmailVerificationEnabled,
|
"email_verification": common.EmailVerificationEnabled,
|
||||||
"github_oauth": common.GitHubOAuthEnabled,
|
"github_oauth": common.GitHubOAuthEnabled,
|
||||||
"github_client_id": common.GitHubClientId,
|
"github_client_id": common.GitHubClientId,
|
||||||
|
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
|
||||||
|
"linuxdo_client_id": common.LinuxDoClientId,
|
||||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||||
"telegram_bot_name": common.TelegramBotName,
|
"telegram_bot_name": common.TelegramBotName,
|
||||||
"system_name": common.SystemName,
|
"system_name": common.SystemName,
|
||||||
@@ -45,9 +47,9 @@ func GetStatus(c *gin.Context) {
|
|||||||
"footer_html": common.Footer,
|
"footer_html": common.Footer,
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
"wechat_login": common.WeChatAuthEnabled,
|
||||||
"server_address": constant.ServerAddress,
|
"server_address": common.ServerAddress,
|
||||||
"price": constant.Price,
|
"stripe_unit_price": common.StripeUnitPrice,
|
||||||
"min_topup": constant.MinTopUp,
|
"min_topup": common.MinTopUp,
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
"turnstile_check": common.TurnstileCheckEnabled,
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
"turnstile_site_key": common.TurnstileSiteKey,
|
||||||
"top_up_link": common.TopUpLink,
|
"top_up_link": common.TopUpLink,
|
||||||
@@ -57,11 +59,10 @@ func GetStatus(c *gin.Context) {
|
|||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
"enable_batch_update": common.BatchUpdateEnabled,
|
"enable_batch_update": common.BatchUpdateEnabled,
|
||||||
"enable_drawing": common.DrawingEnabled,
|
"enable_drawing": common.DrawingEnabled,
|
||||||
"enable_task": common.TaskEnabled,
|
|
||||||
"enable_data_export": common.DataExportEnabled,
|
"enable_data_export": common.DataExportEnabled,
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
"payment_enabled": common.PaymentEnabled,
|
||||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -148,7 +149,7 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if common.EmailAliasRestrictionEnabled {
|
if common.EmailAliasRestrictionEnabled {
|
||||||
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
|
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
|
||||||
if containsSpecialSymbols {
|
if containsSpecialSymbols {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -204,7 +205,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
code := common.GenerateVerificationCode(0)
|
code := common.GenerateVerificationCode(0)
|
||||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
|
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
||||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||||
|
|||||||
@@ -4,28 +4,49 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel/ai360"
|
"one-api/relay/channel/ai360"
|
||||||
"one-api/relay/channel/lingyiwanwu"
|
|
||||||
"one-api/relay/channel/minimax"
|
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
"one-api/relay/channel/lingyiwanwu"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
|
|
||||||
var openAIModels []dto.OpenAIModels
|
type OpenAIModelPermission struct {
|
||||||
var openAIModelsMap map[string]dto.OpenAIModels
|
Id string `json:"id"`
|
||||||
var channelId2Models map[int][]string
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||||
|
AllowSampling bool `json:"allow_sampling"`
|
||||||
|
AllowLogprobs bool `json:"allow_logprobs"`
|
||||||
|
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||||
|
AllowView bool `json:"allow_view"`
|
||||||
|
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||||
|
Organization string `json:"organization"`
|
||||||
|
Group *string `json:"group"`
|
||||||
|
IsBlocking bool `json:"is_blocking"`
|
||||||
|
}
|
||||||
|
|
||||||
func getPermission() []dto.OpenAIModelPermission {
|
type OpenAIModels struct {
|
||||||
var permission []dto.OpenAIModelPermission
|
Id string `json:"id"`
|
||||||
permission = append(permission, dto.OpenAIModelPermission{
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
Permission []OpenAIModelPermission `json:"permission"`
|
||||||
|
Root string `json:"root"`
|
||||||
|
Parent *string `json:"parent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var openAIModels []OpenAIModels
|
||||||
|
var openAIModelsMap map[string]OpenAIModels
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var permission []OpenAIModelPermission
|
||||||
|
permission = append(permission, OpenAIModelPermission{
|
||||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||||
Object: "model_permission",
|
Object: "model_permission",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
@@ -39,12 +60,7 @@ func getPermission() []dto.OpenAIModelPermission {
|
|||||||
Group: nil,
|
Group: nil,
|
||||||
IsBlocking: false,
|
IsBlocking: false,
|
||||||
})
|
})
|
||||||
return permission
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
permission := getPermission()
|
|
||||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||||
if i == relayconstant.APITypeAIProxyLibrary {
|
if i == relayconstant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
@@ -53,7 +69,7 @@ func init() {
|
|||||||
channelName := adaptor.GetChannelName()
|
channelName := adaptor.GetChannelName()
|
||||||
modelNames := adaptor.GetModelList()
|
modelNames := adaptor.GetModelList()
|
||||||
for _, modelName := range modelNames {
|
for _, modelName := range modelNames {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
@@ -65,51 +81,40 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, modelName := range ai360.ModelList {
|
for _, modelName := range ai360.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: ai360.ChannelName,
|
OwnedBy: "360",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: modelName,
|
Root: modelName,
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range moonshot.ModelList {
|
for _, modelName := range moonshot.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: moonshot.ChannelName,
|
OwnedBy: "moonshot",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: modelName,
|
Root: modelName,
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range lingyiwanwu.ModelList {
|
for _, modelName := range lingyiwanwu.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: lingyiwanwu.ChannelName,
|
OwnedBy: "lingyiwanwu",
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for _, modelName := range minimax.ModelList {
|
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
|
||||||
Id: modelName,
|
|
||||||
Object: "model",
|
|
||||||
Created: 1626777600,
|
|
||||||
OwnedBy: minimax.ChannelName,
|
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: modelName,
|
Root: modelName,
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
@@ -119,20 +124,9 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, aiModel := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
openAIModelsMap[aiModel.Id] = aiModel
|
openAIModelsMap[model.Id] = model
|
||||||
}
|
|
||||||
channelId2Models = make(map[int][]string)
|
|
||||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
|
||||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
|
||||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
|
||||||
adaptor.Init(meta, dto.GeneralOpenAIRequest{})
|
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,47 +141,29 @@ func ListModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
models := model.GetGroupModels(user.Group)
|
models := model.GetGroupModels(user.Group)
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]OpenAIModels, 0)
|
||||||
permission := getPermission()
|
|
||||||
for _, s := range models {
|
for _, s := range models {
|
||||||
if _, ok := openAIModelsMap[s]; ok {
|
if _, ok := openAIModelsMap[s]; ok {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||||
} else {
|
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
|
||||||
Id: s,
|
|
||||||
Object: "model",
|
|
||||||
Created: 1626777600,
|
|
||||||
OwnedBy: "custom",
|
|
||||||
Permission: permission,
|
|
||||||
Root: s,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"object": "list",
|
||||||
"data": userOpenAiModels,
|
"data": userOpenAiModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChannelListModels(c *gin.Context) {
|
func ChannelListModels(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"object": "list",
|
||||||
"data": openAIModels,
|
"data": openAIModels,
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func DashboardListModels(c *gin.Context) {
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"data": channelId2Models,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
if model, ok := openAIModelsMap[modelId]; ok {
|
||||||
c.JSON(200, aiModel)
|
c.JSON(200, model)
|
||||||
} else {
|
} else {
|
||||||
openAIError := dto.OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
|
|||||||
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "LinuxDoOAuthEnabled":
|
||||||
|
if option.Value == "true" && common.LinuxDoClientId == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无法启用 LINUX DO OAuth,请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret!",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
case "EmailDomainRestrictionEnabled":
|
case "EmailDomainRestrictionEnabled":
|
||||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetPricing(c *gin.Context) {
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
// if no login, get default group ratio
|
|
||||||
groupRatio := common.GetGroupRatio("default")
|
|
||||||
group, err := model.CacheGetUserGroup(userId)
|
|
||||||
if err == nil {
|
|
||||||
groupRatio = common.GetGroupRatio(group)
|
|
||||||
}
|
|
||||||
pricing := model.GetPricing(group)
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"data": pricing,
|
|
||||||
"group_ratio": groupRatio,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func ResetModelRatio(c *gin.Context) {
|
|
||||||
defaultStr := common.DefaultModelRatio2JSONString()
|
|
||||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = common.UpdateModelRatioByJSONString(defaultStr)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "重置模型倍率成功",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -29,8 +29,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.AudioHelper(c, relayMode)
|
err = relay.AudioHelper(c, relayMode)
|
||||||
case relayconstant.RelayModeRerank:
|
|
||||||
err = relay.RerankHelper(c, relayMode)
|
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
@@ -42,13 +40,12 @@ func Relay(c *gin.Context) {
|
|||||||
retryTimes := common.RetryTimes
|
retryTimes := common.RetryTimes
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
openaiErr := relayHandler(c, relayMode)
|
openaiErr := relayHandler(c, relayMode)
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
useChannel := []int{channelId}
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
go processChannelError(c, channelId, channelType, openaiErr)
|
go processChannelError(c, channelId, openaiErr)
|
||||||
} else {
|
} else {
|
||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
@@ -59,9 +56,7 @@ func Relay(c *gin.Context) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel = append(useChannel, channelId)
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
||||||
c.Set("use_channel", useChannel)
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
@@ -69,10 +64,9 @@ func Relay(c *gin.Context) {
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
openaiErr = relayHandler(c, relayMode)
|
openaiErr = relayHandler(c, relayMode)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
go processChannelError(c, channelId, channel.Type, openaiErr)
|
go processChannelError(c, channelId, openaiErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||||
@@ -128,10 +122,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
|
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||||
channelName := c.GetString("channel_name")
|
channelName := c.GetString("channel_name")
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||||
}
|
}
|
||||||
@@ -193,94 +187,3 @@ func RelayNotFound(c *gin.Context) {
|
|||||||
"error": err,
|
"error": err,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayTask(c *gin.Context) {
|
|
||||||
retryTimes := common.RetryTimes
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
relayMode := c.GetInt("relay_mode")
|
|
||||||
group := c.GetString("group")
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
|
||||||
taskErr := taskRelayHandler(c, relayMode)
|
|
||||||
if taskErr == nil {
|
|
||||||
retryTimes = 0
|
|
||||||
}
|
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
channelId = channel.Id
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
||||||
c.Set("use_channel", useChannel)
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
taskErr = taskRelayHandler(c, relayMode)
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
|
||||||
}
|
|
||||||
if taskErr != nil {
|
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
|
||||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
|
||||||
c.JSON(taskErr.StatusCode, taskErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
|
||||||
var err *dto.TaskError
|
|
||||||
switch relayMode {
|
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
|
||||||
default:
|
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
|
||||||
if taskErr == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if retryTimes <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, ok := c.Get("specific_channel_id"); ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode == 307 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode/100 == 5 {
|
|
||||||
// 超时不重试
|
|
||||||
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode == http.StatusBadRequest {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode == 408 {
|
|
||||||
// azure处理超时不重试
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if taskErr.LocalError {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if taskErr.StatusCode/100 == 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
97
controller/stripe.go
Normal file
97
controller/stripe.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stripe/stripe-go/v76"
|
||||||
|
"github.com/stripe/stripe-go/v76/webhook"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StripeWebhook(c *gin.Context) {
|
||||||
|
payload, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := c.GetHeader("Stripe-Signature")
|
||||||
|
endpointSecret := common.StripeWebhookSecret
|
||||||
|
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch event.Type {
|
||||||
|
case stripe.EventTypeCheckoutSessionCompleted:
|
||||||
|
sessionCompleted(event)
|
||||||
|
case stripe.EventTypeCheckoutSessionExpired:
|
||||||
|
sessionExpired(event)
|
||||||
|
default:
|
||||||
|
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionCompleted(event stripe.Event) {
|
||||||
|
customerId := event.GetObjectValue("customer")
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "complete" != status {
|
||||||
|
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := model.Recharge(referenceId, customerId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err.Error(), referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||||
|
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||||
|
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionExpired(event stripe.Event) {
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "expired" != status {
|
||||||
|
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if "" == referenceId {
|
||||||
|
log.Println("未提供支付单号")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||||
|
if topUp == nil {
|
||||||
|
log.Println("充值订单不存在", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if topUp.Status != common.TopUpStatusPending {
|
||||||
|
log.Println("充值订单状态错误", referenceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp.Status = common.TopUpStatusExpired
|
||||||
|
err := topUp.Update()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("充值订单已过期", referenceId)
|
||||||
|
}
|
||||||
@@ -1,284 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/model"
|
|
||||||
"one-api/relay"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func UpdateTaskBulk() {
|
|
||||||
//revocer
|
|
||||||
//imageModel := "midjourney"
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(15) * time.Second)
|
|
||||||
common.SysLog("任务进度轮询开始")
|
|
||||||
ctx := context.TODO()
|
|
||||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
|
||||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
|
||||||
for _, t := range allTasks {
|
|
||||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
|
||||||
}
|
|
||||||
for platform, tasks := range platformTask {
|
|
||||||
if len(tasks) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
taskChannelM := make(map[int][]string)
|
|
||||||
taskM := make(map[string]*model.Task)
|
|
||||||
nullTaskIds := make([]int64, 0)
|
|
||||||
for _, task := range tasks {
|
|
||||||
if task.TaskID == "" {
|
|
||||||
// 统计失败的未完成任务
|
|
||||||
nullTaskIds = append(nullTaskIds, task.ID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
taskM[task.TaskID] = task
|
|
||||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
|
|
||||||
}
|
|
||||||
if len(nullTaskIds) > 0 {
|
|
||||||
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
|
||||||
"status": "FAILURE",
|
|
||||||
"progress": "100%",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
|
||||||
} else {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(taskChannelM) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
|
||||||
}
|
|
||||||
common.SysLog("任务进度轮询完成")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
|
||||||
switch platform {
|
|
||||||
case constant.TaskPlatformMidjourney:
|
|
||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
|
||||||
case constant.TaskPlatformSuno:
|
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
|
||||||
default:
|
|
||||||
common.SysLog("未知平台")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
|
||||||
for channelId, taskIds := range taskChannelM {
|
|
||||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
|
||||||
if len(taskIds) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
channel, err := model.CacheGetChannel(channelId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
|
||||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
|
||||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
|
||||||
"status": "FAILURE",
|
|
||||||
"progress": "100%",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
|
|
||||||
if adaptor == nil {
|
|
||||||
return errors.New("adaptor not found")
|
|
||||||
}
|
|
||||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
|
||||||
"ids": taskIds,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
|
||||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !responseItems.IsSuccess() {
|
|
||||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, responseItem := range responseItems.Data {
|
|
||||||
task := taskM[responseItem.TaskID]
|
|
||||||
if !checkTaskNeedUpdate(task, responseItem) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
|
||||||
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
|
||||||
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
|
||||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
|
||||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
|
||||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
|
||||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
|
||||||
task.Progress = "100%"
|
|
||||||
err = model.CacheUpdateUserQuota(task.UserId)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
|
||||||
} else {
|
|
||||||
quota := task.Quota
|
|
||||||
if quota != 0 {
|
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
|
||||||
}
|
|
||||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if responseItem.Status == model.TaskStatusSuccess {
|
|
||||||
task.Progress = "100%"
|
|
||||||
}
|
|
||||||
task.Data = responseItem.Data
|
|
||||||
|
|
||||||
err = task.Update()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
|
||||||
|
|
||||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if oldTask.StartTime != newTask.StartTime {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if oldTask.FinishTime != newTask.FinishTime {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if string(oldTask.Status) != newTask.Status {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if oldTask.FailReason != newTask.FailReason {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if oldTask.FinishTime != newTask.FinishTime {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
oldData, _ := json.Marshal(oldTask.Data)
|
|
||||||
newData, _ := json.Marshal(newTask.Data)
|
|
||||||
|
|
||||||
sort.Slice(oldData, func(i, j int) bool {
|
|
||||||
return oldData[i] < oldData[j]
|
|
||||||
})
|
|
||||||
sort.Slice(newData, func(i, j int) bool {
|
|
||||||
return newData[i] < newData[j]
|
|
||||||
})
|
|
||||||
|
|
||||||
if string(oldData) != string(newData) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAllTask(c *gin.Context) {
|
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
|
||||||
// 解析其他查询参数
|
|
||||||
queryParams := model.SyncTaskQueryParams{
|
|
||||||
Platform: constant.TaskPlatform(c.Query("platform")),
|
|
||||||
TaskID: c.Query("task_id"),
|
|
||||||
Status: c.Query("status"),
|
|
||||||
Action: c.Query("action"),
|
|
||||||
StartTimestamp: startTimestamp,
|
|
||||||
EndTimestamp: endTimestamp,
|
|
||||||
}
|
|
||||||
|
|
||||||
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
|
||||||
if logs == nil {
|
|
||||||
logs = make([]*model.Task, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserTask(c *gin.Context) {
|
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
|
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
|
||||||
|
|
||||||
queryParams := model.SyncTaskQueryParams{
|
|
||||||
Platform: constant.TaskPlatform(c.Query("platform")),
|
|
||||||
TaskID: c.Query("task_id"),
|
|
||||||
Status: c.Query("status"),
|
|
||||||
Action: c.Query("action"),
|
|
||||||
StartTimestamp: startTimestamp,
|
|
||||||
EndTimestamp: endTimestamp,
|
|
||||||
}
|
|
||||||
|
|
||||||
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
|
||||||
if logs == nil {
|
|
||||||
logs = make([]*model.Task, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,22 +1,20 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/Calcium-Ion/go-epay/epay"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/samber/lo"
|
"github.com/stripe/stripe-go/v76"
|
||||||
|
"github.com/stripe/stripe-go/v76/checkout/session"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EpayRequest struct {
|
type PayRequest struct {
|
||||||
Amount int `json:"amount"`
|
Amount int `json:"amount"`
|
||||||
PaymentMethod string `json:"payment_method"`
|
PaymentMethod string `json:"payment_method"`
|
||||||
TopUpCode string `json:"top_up_code"`
|
TopUpCode string `json:"top_up_code"`
|
||||||
@@ -27,196 +25,114 @@ type AmountRequest struct {
|
|||||||
TopUpCode string `json:"top_up_code"`
|
TopUpCode string `json:"top_up_code"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEpayClient() *epay.Client {
|
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||||
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
|
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
|
||||||
return nil
|
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||||
}
|
}
|
||||||
withUrl, err := epay.NewClient(&epay.Config{
|
|
||||||
PartnerID: constant.EpayId,
|
stripe.Key = common.StripeApiSecret
|
||||||
Key: constant.EpayKey,
|
|
||||||
}, constant.PayAddress)
|
params := &stripe.CheckoutSessionParams{
|
||||||
|
ClientReferenceID: stripe.String(referenceId),
|
||||||
|
SuccessURL: stripe.String(common.ServerAddress + "/log"),
|
||||||
|
CancelURL: stripe.String(common.ServerAddress + "/topup"),
|
||||||
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||||
|
{
|
||||||
|
Price: stripe.String(common.StripePriceId),
|
||||||
|
Quantity: stripe.Int64(amount),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "" == customerId {
|
||||||
|
if "" != email {
|
||||||
|
params.CustomerEmail = stripe.String(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||||
|
} else {
|
||||||
|
params.Customer = stripe.String(customerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := session.New(params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return "", err
|
||||||
}
|
}
|
||||||
return withUrl
|
|
||||||
|
return result.URL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPayMoney(amount float64, user model.User) float64 {
|
func GetPayAmount(count float64) float64 {
|
||||||
if !common.DisplayInCurrencyEnabled {
|
return count * common.StripeUnitPrice
|
||||||
amount = amount / common.QuotaPerUnit
|
|
||||||
}
|
|
||||||
// 别问为什么用float64,问就是这么点钱没必要
|
|
||||||
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
|
|
||||||
if topupGroupRatio == 0 {
|
|
||||||
topupGroupRatio = 1
|
|
||||||
}
|
|
||||||
payMoney := amount * constant.Price * topupGroupRatio
|
|
||||||
return payMoney
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMinTopup() int {
|
func GetChargedAmount(count float64, user model.User) float64 {
|
||||||
minTopup := constant.MinTopUp
|
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||||
if !common.DisplayInCurrencyEnabled {
|
if topUpGroupRatio == 0 {
|
||||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
topUpGroupRatio = 1
|
||||||
}
|
}
|
||||||
return minTopup
|
|
||||||
|
return count * topUpGroupRatio
|
||||||
}
|
}
|
||||||
|
|
||||||
func RequestEpay(c *gin.Context) {
|
func RequestPayLink(c *gin.Context) {
|
||||||
var req EpayRequest
|
var req PayRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Amount < getMinTopup() {
|
if !common.PaymentEnabled {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.PaymentMethod != "stripe" {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount < common.MinTopUp {
|
||||||
|
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount > 10000 {
|
||||||
|
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, _ := model.GetUserById(id, false)
|
user, _ := model.GetUserById(id, false)
|
||||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||||
if payMoney < 0.01 {
|
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var payType epay.PurchaseType
|
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
|
||||||
if req.PaymentMethod == "zfb" {
|
referenceId := "ref_" + common.Sha1(reference)
|
||||||
payType = epay.Alipay
|
|
||||||
}
|
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
|
||||||
if req.PaymentMethod == "wx" {
|
|
||||||
req.PaymentMethod = "wxpay"
|
|
||||||
payType = epay.WechatPay
|
|
||||||
}
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
|
||||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
|
||||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
|
||||||
client := GetEpayClient()
|
|
||||||
if client == nil {
|
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
|
||||||
Type: payType,
|
|
||||||
ServiceTradeNo: "A" + tradeNo,
|
|
||||||
Name: "B" + tradeNo,
|
|
||||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
|
||||||
Device: epay.PC,
|
|
||||||
NotifyUrl: notifyUrl,
|
|
||||||
ReturnUrl: returnUrl,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
amount := req.Amount
|
|
||||||
if !common.DisplayInCurrencyEnabled {
|
|
||||||
amount = amount / int(common.QuotaPerUnit)
|
|
||||||
}
|
|
||||||
topUp := &model.TopUp{
|
topUp := &model.TopUp{
|
||||||
UserId: id,
|
UserId: id,
|
||||||
Amount: amount,
|
Amount: req.Amount,
|
||||||
Money: payMoney,
|
Money: chargedMoney,
|
||||||
TradeNo: "A" + tradeNo,
|
TradeNo: referenceId,
|
||||||
CreateTime: time.Now().Unix(),
|
CreateTime: time.Now().Unix(),
|
||||||
Status: "pending",
|
Status: common.TopUpStatusPending,
|
||||||
}
|
}
|
||||||
err = topUp.Insert()
|
err = topUp.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
c.JSON(200, gin.H{
|
||||||
}
|
"message": "success",
|
||||||
|
"data": gin.H{
|
||||||
// tradeNo lock
|
"payLink": payLink,
|
||||||
var orderLocks sync.Map
|
},
|
||||||
var createLock sync.Mutex
|
})
|
||||||
|
|
||||||
// LockOrder 尝试对给定订单号加锁
|
|
||||||
func LockOrder(tradeNo string) {
|
|
||||||
lock, ok := orderLocks.Load(tradeNo)
|
|
||||||
if !ok {
|
|
||||||
createLock.Lock()
|
|
||||||
defer createLock.Unlock()
|
|
||||||
lock, ok = orderLocks.Load(tradeNo)
|
|
||||||
if !ok {
|
|
||||||
lock = new(sync.Mutex)
|
|
||||||
orderLocks.Store(tradeNo, lock)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lock.(*sync.Mutex).Lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlockOrder 释放给定订单号的锁
|
|
||||||
func UnlockOrder(tradeNo string) {
|
|
||||||
lock, ok := orderLocks.Load(tradeNo)
|
|
||||||
if ok {
|
|
||||||
lock.(*sync.Mutex).Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func EpayNotify(c *gin.Context) {
|
|
||||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
|
||||||
r[t] = c.Request.URL.Query().Get(t)
|
|
||||||
return r
|
|
||||||
}, map[string]string{})
|
|
||||||
client := GetEpayClient()
|
|
||||||
if client == nil {
|
|
||||||
log.Println("易支付回调失败 未找到配置信息")
|
|
||||||
_, err := c.Writer.Write([]byte("fail"))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("易支付回调写入失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
verifyInfo, err := client.Verify(params)
|
|
||||||
if err == nil && verifyInfo.VerifyStatus {
|
|
||||||
_, err := c.Writer.Write([]byte("success"))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("易支付回调写入失败")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, err := c.Writer.Write([]byte("fail"))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("易支付回调写入失败")
|
|
||||||
}
|
|
||||||
log.Println("易支付回调签名验证失败")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
|
||||||
log.Println(verifyInfo)
|
|
||||||
LockOrder(verifyInfo.ServiceTradeNo)
|
|
||||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
|
||||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
|
||||||
if topUp == nil {
|
|
||||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if topUp.Status == "pending" {
|
|
||||||
topUp.Status = "success"
|
|
||||||
err := topUp.Update()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
|
||||||
//user.Quota += topUp.Amount * 500000
|
|
||||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
|
||||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RequestAmount(c *gin.Context) {
|
func RequestAmount(c *gin.Context) {
|
||||||
@@ -226,17 +142,23 @@ func RequestAmount(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !common.PaymentEnabled {
|
||||||
if req.Amount < getMinTopup() {
|
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
return
|
||||||
|
}
|
||||||
|
if req.Amount < common.MinTopUp {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, _ := model.GetUserById(id, false)
|
user, _ := model.GetUserById(id, false)
|
||||||
payMoney := getPayMoney(float64(req.Amount), *user)
|
payMoney := GetPayAmount(float64(req.Amount))
|
||||||
if payMoney <= 0.01 {
|
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
c.JSON(200, gin.H{
|
||||||
return
|
"message": "success",
|
||||||
}
|
"data": gin.H{
|
||||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||||
|
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
|||||||
session.Set("username", user.Username)
|
session.Set("username", user.Username)
|
||||||
session.Set("role", user.Role)
|
session.Set("role", user.Role)
|
||||||
session.Set("status", user.Status)
|
session.Set("status", user.Status)
|
||||||
|
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
|
||||||
err := session.Save()
|
err := session.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -517,7 +518,7 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteUser(c *gin.Context) {
|
func HardDeleteUser(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -526,7 +527,7 @@ func DeleteUser(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
originUser, err := model.GetUserById(id, false)
|
originUser, err := model.GetUserByIdUnscoped(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -550,9 +551,23 @@ func DeleteUser(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSelf(c *gin.Context) {
|
func DeleteSelf(c *gin.Context) {
|
||||||
|
if !common.UserSelfDeletionEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "当前设置不允许用户自我删除账号",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, _ := model.GetUserById(id, false)
|
user, _ := model.GetUserById(id, false)
|
||||||
|
|
||||||
|
|||||||
@@ -2,18 +2,17 @@ version: '3.4'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
new-api:
|
new-api:
|
||||||
image: calciumion/new-api:latest
|
image: pengzhile/new-api:latest
|
||||||
# build: .
|
|
||||||
container_name: new-api
|
container_name: new-api
|
||||||
restart: always
|
restart: always
|
||||||
command: --log-dir /app/logs
|
command: --log-dir /app/logs
|
||||||
ports:
|
ports:
|
||||||
- "3000:3000"
|
- "3000:3000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/data
|
- ./data/new-api:/data
|
||||||
- ./logs:/app/logs
|
- ./logs:/app/logs
|
||||||
environment:
|
environment:
|
||||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
@@ -23,13 +22,22 @@ services:
|
|||||||
|
|
||||||
depends_on:
|
depends_on:
|
||||||
- redis
|
- redis
|
||||||
healthcheck:
|
- db
|
||||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
image: redis:latest
|
image: redis:latest
|
||||||
container_name: redis
|
container_name: redis
|
||||||
restart: always
|
restart: always
|
||||||
|
|
||||||
|
db:
|
||||||
|
image: mysql:8.2.0
|
||||||
|
container_name: mysql
|
||||||
|
restart: always
|
||||||
|
volumes:
|
||||||
|
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||||
|
environment:
|
||||||
|
TZ: Asia/Shanghai # 设置时区
|
||||||
|
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||||
|
MYSQL_USER: newapi # 创建专用用户
|
||||||
|
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||||
|
MYSQL_DATABASE: new-api # 自动创建数据库
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
type OpenAIModelPermission struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
|
||||||
AllowSampling bool `json:"allow_sampling"`
|
|
||||||
AllowLogprobs bool `json:"allow_logprobs"`
|
|
||||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
|
||||||
AllowView bool `json:"allow_view"`
|
|
||||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
||||||
Organization string `json:"organization"`
|
|
||||||
Group *string `json:"group"`
|
|
||||||
IsBlocking bool `json:"is_blocking"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIModels struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
OwnedBy string `json:"owned_by"`
|
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
|
||||||
Root string `json:"root"`
|
|
||||||
Parent *string `json:"parent"`
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
type RerankRequest struct {
|
|
||||||
Documents []any `json:"documents"`
|
|
||||||
Query string `json:"query"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
TopN int `json:"top_n"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RerankResponseDocument struct {
|
|
||||||
Document any `json:"document"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
RelevanceScore float64 `json:"relevance_score"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RerankResponse struct {
|
|
||||||
Results []RerankResponseDocument `json:"results"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
129
dto/suno.go
129
dto/suno.go
@@ -1,129 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TaskData interface {
|
|
||||||
SunoDataResponse | []SunoDataResponse | string | any
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoSubmitReq struct {
|
|
||||||
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
|
|
||||||
Prompt string `json:"prompt,omitempty"`
|
|
||||||
Mv string `json:"mv,omitempty"`
|
|
||||||
Title string `json:"title,omitempty"`
|
|
||||||
Tags string `json:"tags,omitempty"`
|
|
||||||
ContinueAt float64 `json:"continue_at,omitempty"`
|
|
||||||
TaskID string `json:"task_id,omitempty"`
|
|
||||||
ContinueClipId string `json:"continue_clip_id,omitempty"`
|
|
||||||
MakeInstrumental bool `json:"make_instrumental"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FetchReq struct {
|
|
||||||
IDs []string `json:"ids"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoDataResponse struct {
|
|
||||||
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
|
|
||||||
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
|
||||||
Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
|
|
||||||
FailReason string `json:"fail_reason"`
|
|
||||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
|
||||||
StartTime int64 `json:"start_time" gorm:"index"`
|
|
||||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
|
||||||
Data json.RawMessage `json:"data" gorm:"type:json"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoSong struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
VideoURL string `json:"video_url"`
|
|
||||||
AudioURL string `json:"audio_url"`
|
|
||||||
ImageURL string `json:"image_url"`
|
|
||||||
ImageLargeURL string `json:"image_large_url"`
|
|
||||||
MajorModelVersion string `json:"major_model_version"`
|
|
||||||
ModelName string `json:"model_name"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
Metadata SunoMetadata `json:"metadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoMetadata struct {
|
|
||||||
Tags string `json:"tags"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
|
|
||||||
AudioPromptID interface{} `json:"audio_prompt_id"`
|
|
||||||
Duration interface{} `json:"duration"`
|
|
||||||
ErrorType interface{} `json:"error_type"`
|
|
||||||
ErrorMessage interface{} `json:"error_message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoLyrics struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const TaskSuccessCode = "success"
|
|
||||||
|
|
||||||
type TaskResponse[T TaskData] struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data T `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TaskResponse[T]) IsSuccess() bool {
|
|
||||||
return t.Code == TaskSuccessCode
|
|
||||||
}
|
|
||||||
|
|
||||||
type TaskDto struct {
|
|
||||||
TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
|
|
||||||
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
|
|
||||||
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
|
|
||||||
FailReason string `json:"fail_reason"`
|
|
||||||
SubmitTime int64 `json:"submit_time"`
|
|
||||||
StartTime int64 `json:"start_time"`
|
|
||||||
FinishTime int64 `json:"finish_time"`
|
|
||||||
Progress string `json:"progress"`
|
|
||||||
Data json.RawMessage `json:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoGoAPISubmitReq struct {
|
|
||||||
CustomMode bool `json:"custom_mode"`
|
|
||||||
|
|
||||||
Input SunoGoAPISubmitReqInput `json:"input"`
|
|
||||||
|
|
||||||
NotifyHook string `json:"notify_hook,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SunoGoAPISubmitReqInput struct {
|
|
||||||
GptDescriptionPrompt string `json:"gpt_description_prompt"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Mv string `json:"mv"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Tags string `json:"tags"`
|
|
||||||
ContinueAt float64 `json:"continue_at"`
|
|
||||||
TaskID string `json:"task_id"`
|
|
||||||
ContinueClipId string `json:"continue_clip_id"`
|
|
||||||
MakeInstrumental bool `json:"make_instrumental"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GoAPITaskResponse[T any] struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data T `json:"data"`
|
|
||||||
ErrorMessage string `json:"error_message,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GoAPITaskResponseData struct {
|
|
||||||
TaskID string `json:"task_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GoAPIFetchResponseData struct {
|
|
||||||
TaskID string `json:"task_id"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Input string `json:"input"`
|
|
||||||
Clips map[string]SunoSong `json:"clips"`
|
|
||||||
}
|
|
||||||
10
dto/task.go
10
dto/task.go
@@ -1,10 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
type TaskError struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data any `json:"data"`
|
|
||||||
StatusCode int `json:"-"`
|
|
||||||
LocalError bool `json:"-"`
|
|
||||||
Error error `json:"-"`
|
|
||||||
}
|
|
||||||
@@ -11,7 +11,6 @@ type GeneralOpenAIRequest struct {
|
|||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
Prompt any `json:"prompt,omitempty"`
|
Prompt any `json:"prompt,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
|
||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
@@ -44,10 +43,6 @@ type OpenAIFunction struct {
|
|||||||
Parameters any `json:"parameters,omitempty"`
|
Parameters any `json:"parameters,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamOptions struct {
|
|
||||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||||
return int64(r.MaxTokens)
|
return int64(r.MaxTokens)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,12 +102,10 @@ type ChatCompletionsStreamResponse struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
SystemFingerprint *string `json:"system_fingerprint"`
|
SystemFingerprint *string `json:"system_fingerprint"`
|
||||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||||
Usage *Usage `json:"usage"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseSimple struct {
|
type ChatCompletionsStreamResponseSimple struct {
|
||||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||||
Usage *Usage `json:"usage"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionsStreamResponse struct {
|
type CompletionsStreamResponse struct {
|
||||||
|
|||||||
9
go.mod
9
go.mod
@@ -4,7 +4,6 @@ module one-api
|
|||||||
go 1.18
|
go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Calcium-Ion/go-epay v0.0.2
|
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||||
@@ -17,13 +16,14 @@ require (
|
|||||||
github.com/go-playground/validator/v10 v10.19.0
|
github.com/go-playground/validator/v10 v10.19.0
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/jinzhu/copier v0.4.0
|
github.com/jinzhu/copier v0.4.0
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.7
|
github.com/pkoukk/tiktoken-go v0.1.6
|
||||||
github.com/samber/lo v1.39.0
|
github.com/samber/lo v1.39.0
|
||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
|
github.com/stripe/stripe-go/v76 v76.21.0
|
||||||
golang.org/x/crypto v0.21.0
|
golang.org/x/crypto v0.21.0
|
||||||
golang.org/x/image v0.15.0
|
golang.org/x/image v0.15.0
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
@@ -42,7 +42,7 @@ require (
|
|||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
@@ -64,7 +64,6 @@ require (
|
|||||||
github.com/leodido/go-urn v1.4.0 // indirect
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||||
|
|||||||
20
go.sum
20
go.sum
@@ -1,5 +1,3 @@
|
|||||||
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
|
|
||||||
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||||
@@ -32,8 +30,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
@@ -81,8 +79,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
|||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||||
@@ -131,8 +129,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
|||||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -148,8 +144,8 @@ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNc
|
|||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||||
@@ -171,6 +167,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
|
||||||
|
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
|
||||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
@@ -196,6 +194,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
|
|||||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||||
@@ -203,6 +202,7 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|||||||
11
main.go
11
main.go
@@ -89,14 +89,9 @@ func main() {
|
|||||||
}
|
}
|
||||||
go controller.AutomaticallyTestChannels(frequency)
|
go controller.AutomaticallyTestChannels(frequency)
|
||||||
}
|
}
|
||||||
if common.IsMasterNode {
|
common.SafeGoroutine(func() {
|
||||||
common.SafeGoroutine(func() {
|
controller.UpdateMidjourneyTaskBulk()
|
||||||
controller.UpdateMidjourneyTaskBulk()
|
})
|
||||||
})
|
|
||||||
common.SafeGoroutine(func() {
|
|
||||||
controller.UpdateTaskBulk()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||||
common.BatchUpdateEnabled = true
|
common.BatchUpdateEnabled = true
|
||||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
role := session.Get("role")
|
role := session.Get("role")
|
||||||
id := session.Get("id")
|
id := session.Get("id")
|
||||||
status := session.Get("status")
|
status := session.Get("status")
|
||||||
|
linuxDoEnable := session.Get("linuxdo_enable")
|
||||||
if username == nil {
|
if username == nil {
|
||||||
// Check access token
|
// Check access token
|
||||||
accessToken := c.Request.Header.Get("Authorization")
|
accessToken := c.Request.Header.Get("Authorization")
|
||||||
@@ -33,6 +34,7 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
role = user.Role
|
role = user.Role
|
||||||
id = user.Id
|
id = user.Id
|
||||||
status = user.Status
|
status = user.Status
|
||||||
|
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -50,6 +52,14 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户 LINUX DO 信任等级不足",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
if role.(int) < minRole {
|
if role.(int) < minRole {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -64,17 +74,6 @@ func authHelper(c *gin.Context, minRole int) {
|
|||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TryUserAuth() func(c *gin.Context) {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
session := sessions.Default(c)
|
|
||||||
id := session.Get("id")
|
|
||||||
if id != nil {
|
|
||||||
c.Set("id", id)
|
|
||||||
}
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UserAuth() func(c *gin.Context) {
|
func UserAuth() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
authHelper(c, common.RoleCommonUser)
|
authHelper(c, common.RoleCommonUser)
|
||||||
@@ -123,6 +122,15 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
|
||||||
|
if err != nil {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !linuxDoEnabled {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
|
||||||
|
return
|
||||||
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
c.Set("token_id", token.Id)
|
c.Set("token_id", token.Id)
|
||||||
c.Set("token_name", token.Name)
|
c.Set("token_name", token.Name)
|
||||||
|
|||||||
@@ -125,17 +125,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
modelRequest.Model = midjourneyModel
|
modelRequest.Model = midjourneyModel
|
||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.Contains(c.Request.URL.Path, "/suno/") {
|
|
||||||
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
|
|
||||||
if relayMode == relayconstant.RelayModeSunoFetch ||
|
|
||||||
relayMode == relayconstant.RelayModeSunoFetchByID {
|
|
||||||
shouldSelectChannel = false
|
|
||||||
} else {
|
|
||||||
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
|
|
||||||
modelRequest.Model = modelName
|
|
||||||
}
|
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
|
||||||
c.Set("relay_mode", relayMode)
|
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
@@ -178,7 +167,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("channel_type", channel.Type)
|
|
||||||
ban := true
|
ban := true
|
||||||
// parse *int to bool
|
// parse *int to bool
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||||
|
|||||||
@@ -29,13 +29,6 @@ func GetGroupModels(group string) []string {
|
|||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetEnabledModels() []string {
|
|
||||||
var models []string
|
|
||||||
// Find distinct models
|
|
||||||
DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPriority(group string, model string, retry int) (int, error) {
|
func getPriority(group string, model string, retry int) (int, error) {
|
||||||
groupCol := "`group`"
|
groupCol := "`group`"
|
||||||
trueVal := "1"
|
trueVal := "1"
|
||||||
@@ -56,11 +49,6 @@ func getPriority(group string, model string, retry int) (int, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(priorities) == 0 {
|
|
||||||
// 如果没有查询到优先级,则返回错误
|
|
||||||
return 0, errors.New("数据库一致性被破坏")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确定要使用的优先级
|
// 确定要使用的优先级
|
||||||
var priorityToUse int
|
var priorityToUse int
|
||||||
if retry >= len(priorities) {
|
if retry >= len(priorities) {
|
||||||
@@ -204,7 +192,7 @@ func FixAbility() (int, error) {
|
|||||||
|
|
||||||
// Use channelIds to find channel not in abilities table
|
// Use channelIds to find channel not in abilities table
|
||||||
var abilityChannelIds []int
|
var abilityChannelIds []int
|
||||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func SyncTokenCache(frequency int) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 如果数据库中存在,先检查redis
|
// 如果数据库中存在,先检查redis
|
||||||
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
|
_, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 如果redis中不存在,则跳过
|
// 如果redis中不存在,则跳过
|
||||||
continue
|
continue
|
||||||
@@ -205,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
|||||||
return userEnabled, err
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return IsLinuxDoEnabled(userId)
|
||||||
|
}
|
||||||
|
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
|
||||||
|
if err == nil {
|
||||||
|
return enabled == "1", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
enabled = "0"
|
||||||
|
if linuxDoEnabled {
|
||||||
|
enabled = "1"
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Redis set linuxdo enabled error: " + err.Error())
|
||||||
|
}
|
||||||
|
return linuxDoEnabled, err
|
||||||
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
var channelsIDM map[int]*Channel
|
var channelsIDM map[int]*Channel
|
||||||
var channelSyncLock sync.RWMutex
|
var channelSyncLock sync.RWMutex
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
@@ -31,38 +29,6 @@ type Channel struct {
|
|||||||
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||||
OtherInfo string `json:"other_info"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) GetModels() []string {
|
|
||||||
if channel.Models == "" {
|
|
||||||
return []string{}
|
|
||||||
}
|
|
||||||
return strings.Split(strings.Trim(channel.Models, ","), ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) GetOtherInfo() map[string]interface{} {
|
|
||||||
otherInfo := make(map[string]interface{})
|
|
||||||
if channel.OtherInfo != "" {
|
|
||||||
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to unmarshal other info: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return otherInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
|
||||||
otherInfoBytes, err := json.Marshal(otherInfo)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to marshal other info: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
channel.OtherInfo = string(otherInfoBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) Save() error {
|
|
||||||
return DB.Save(channel).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||||
@@ -247,31 +213,15 @@ func (channel *Channel) Delete() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelStatusById(id int, status int, reason string) {
|
func UpdateChannelStatusById(id int, status int) {
|
||||||
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to update ability status: " + err.Error())
|
common.SysError("failed to update ability status: " + err.Error())
|
||||||
}
|
}
|
||||||
channel, err := GetChannelById(id, true)
|
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// find channel by id error, directly update status
|
common.SysError("failed to update channel status: " + err.Error())
|
||||||
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// find channel by id success, update status and other info
|
|
||||||
info := channel.GetOtherInfo()
|
|
||||||
info["status_reason"] = reason
|
|
||||||
info["status_time"] = common.GetTimestamp()
|
|
||||||
channel.SetOtherInfo(info)
|
|
||||||
channel.Status = status
|
|
||||||
err = channel.Save()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to update channel status: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateChannelUsedQuota(id int, quota int) {
|
func UpdateChannelUsedQuota(id int, quota int) {
|
||||||
|
|||||||
36
model/log.go
36
model/log.go
@@ -24,7 +24,6 @@ type Log struct {
|
|||||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||||
ChannelId int `json:"channel" gorm:"index"`
|
ChannelId int `json:"channel" gorm:"index"`
|
||||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||||
Other string `json:"other"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -36,7 +35,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||||
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
|
err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.Split(key, "-")[1]).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,13 +57,12 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
username, _ := CacheGetUsername(userId)
|
username, _ := CacheGetUsername(userId)
|
||||||
otherStr := common.MapToJsonStr(other)
|
|
||||||
log := &Log{
|
log := &Log{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
Username: username,
|
Username: username,
|
||||||
@@ -80,7 +78,6 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
TokenId: tokenId,
|
TokenId: tokenId,
|
||||||
UseTime: useTimeSeconds,
|
UseTime: useTimeSeconds,
|
||||||
IsStream: isStream,
|
IsStream: isStream,
|
||||||
Other: otherStr,
|
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -93,7 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB
|
tx = DB
|
||||||
@@ -118,11 +115,17 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
if channel != 0 {
|
if channel != 0 {
|
||||||
tx = tx.Where("channel_id = ?", channel)
|
tx = tx.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.Model(&Log{}).Count(&total).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||||
return logs, err
|
return logs, total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB.Where("user_id = ?", userId)
|
tx = DB.Where("user_id = ?", userId)
|
||||||
@@ -141,17 +144,14 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
if endTimestamp != 0 {
|
if endTimestamp != 0 {
|
||||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||||
}
|
}
|
||||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
|
|
||||||
for i := range logs {
|
err = tx.Model(&Log{}).Count(&total).Error
|
||||||
var otherMap map[string]interface{}
|
if err != nil {
|
||||||
otherMap = common.StrToMap(logs[i].Other)
|
return nil, 0, err
|
||||||
if otherMap != nil {
|
|
||||||
// delete admin
|
|
||||||
delete(otherMap, "admin_info")
|
|
||||||
}
|
|
||||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
|
||||||
}
|
}
|
||||||
return logs, err
|
|
||||||
|
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
|
||||||
|
return logs, total, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||||
|
|||||||
@@ -86,19 +86,19 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
|
||||||
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
|
||||||
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
|
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
|
||||||
|
|
||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
//if common.UsingMySQL {
|
if common.UsingMySQL {
|
||||||
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
_, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
||||||
//}
|
}
|
||||||
common.SysLog("database migration started")
|
common.SysLog("database migration started")
|
||||||
err = db.AutoMigrate(&Channel{})
|
err = db.AutoMigrate(&Channel{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -140,10 +140,6 @@ func InitDB() (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(&Task{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
common.SysLog("database migrated")
|
common.SysLog("database migrated")
|
||||||
err = createRootAccountIfNeed()
|
err = createRootAccountIfNeed()
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -31,17 +31,18 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||||
|
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
|
||||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||||
|
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
|
||||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||||
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
||||||
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
|
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
|
||||||
common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
|
|
||||||
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
||||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||||
@@ -60,17 +61,18 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["SystemName"] = common.SystemName
|
common.OptionMap["SystemName"] = common.SystemName
|
||||||
common.OptionMap["Logo"] = common.Logo
|
common.OptionMap["Logo"] = common.Logo
|
||||||
common.OptionMap["ServerAddress"] = ""
|
common.OptionMap["ServerAddress"] = ""
|
||||||
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
|
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
|
||||||
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
|
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
|
||||||
common.OptionMap["PayAddress"] = ""
|
common.OptionMap["StripePriceId"] = common.StripePriceId
|
||||||
common.OptionMap["CustomCallbackAddress"] = ""
|
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
|
||||||
common.OptionMap["EpayId"] = ""
|
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
|
||||||
common.OptionMap["EpayKey"] = ""
|
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
|
||||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
|
||||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
|
||||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||||
common.OptionMap["GitHubClientId"] = ""
|
common.OptionMap["GitHubClientId"] = ""
|
||||||
common.OptionMap["GitHubClientSecret"] = ""
|
common.OptionMap["GitHubClientSecret"] = ""
|
||||||
|
common.OptionMap["LinuxDoClientId"] = ""
|
||||||
|
common.OptionMap["LinuxDoClientSecret"] = ""
|
||||||
|
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
|
||||||
common.OptionMap["TelegramBotToken"] = ""
|
common.OptionMap["TelegramBotToken"] = ""
|
||||||
common.OptionMap["TelegramBotName"] = ""
|
common.OptionMap["TelegramBotName"] = ""
|
||||||
common.OptionMap["WeChatServerAddress"] = ""
|
common.OptionMap["WeChatServerAddress"] = ""
|
||||||
@@ -86,7 +88,6 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
|
||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||||
common.OptionMap["ChatLink"] = common.ChatLink
|
common.OptionMap["ChatLink"] = common.ChatLink
|
||||||
common.OptionMap["ChatLink2"] = common.ChatLink2
|
common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||||
@@ -99,7 +100,6 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
|
|
||||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||||
@@ -173,6 +173,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.EmailVerificationEnabled = boolValue
|
common.EmailVerificationEnabled = boolValue
|
||||||
case "GitHubOAuthEnabled":
|
case "GitHubOAuthEnabled":
|
||||||
common.GitHubOAuthEnabled = boolValue
|
common.GitHubOAuthEnabled = boolValue
|
||||||
|
case "LinuxDoOAuthEnabled":
|
||||||
|
common.LinuxDoOAuthEnabled = boolValue
|
||||||
case "WeChatAuthEnabled":
|
case "WeChatAuthEnabled":
|
||||||
common.WeChatAuthEnabled = boolValue
|
common.WeChatAuthEnabled = boolValue
|
||||||
case "TelegramOAuthEnabled":
|
case "TelegramOAuthEnabled":
|
||||||
@@ -181,6 +183,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.TurnstileCheckEnabled = boolValue
|
common.TurnstileCheckEnabled = boolValue
|
||||||
case "RegisterEnabled":
|
case "RegisterEnabled":
|
||||||
common.RegisterEnabled = boolValue
|
common.RegisterEnabled = boolValue
|
||||||
|
case "UserSelfDeletionEnabled":
|
||||||
|
common.UserSelfDeletionEnabled = boolValue
|
||||||
case "EmailDomainRestrictionEnabled":
|
case "EmailDomainRestrictionEnabled":
|
||||||
common.EmailDomainRestrictionEnabled = boolValue
|
common.EmailDomainRestrictionEnabled = boolValue
|
||||||
case "EmailAliasRestrictionEnabled":
|
case "EmailAliasRestrictionEnabled":
|
||||||
@@ -197,8 +201,6 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
common.DisplayTokenStatEnabled = boolValue
|
common.DisplayTokenStatEnabled = boolValue
|
||||||
case "DrawingEnabled":
|
case "DrawingEnabled":
|
||||||
common.DrawingEnabled = boolValue
|
common.DrawingEnabled = boolValue
|
||||||
case "TaskEnabled":
|
|
||||||
common.TaskEnabled = boolValue
|
|
||||||
case "DataExportEnabled":
|
case "DataExportEnabled":
|
||||||
common.DataExportEnabled = boolValue
|
common.DataExportEnabled = boolValue
|
||||||
case "DefaultCollapseSidebar":
|
case "DefaultCollapseSidebar":
|
||||||
@@ -211,8 +213,6 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
constant.MjModeClearEnabled = boolValue
|
constant.MjModeClearEnabled = boolValue
|
||||||
case "MjForwardUrlEnabled":
|
case "MjForwardUrlEnabled":
|
||||||
constant.MjForwardUrlEnabled = boolValue
|
constant.MjForwardUrlEnabled = boolValue
|
||||||
case "MjActionCheckSuccessEnabled":
|
|
||||||
constant.MjActionCheckSuccessEnabled = boolValue
|
|
||||||
case "CheckSensitiveEnabled":
|
case "CheckSensitiveEnabled":
|
||||||
constant.CheckSensitiveEnabled = boolValue
|
constant.CheckSensitiveEnabled = boolValue
|
||||||
case "CheckSensitiveOnPromptEnabled":
|
case "CheckSensitiveOnPromptEnabled":
|
||||||
@@ -240,29 +240,31 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "SMTPToken":
|
case "SMTPToken":
|
||||||
common.SMTPToken = value
|
common.SMTPToken = value
|
||||||
case "ServerAddress":
|
case "ServerAddress":
|
||||||
constant.ServerAddress = value
|
common.ServerAddress = value
|
||||||
case "WorkerUrl":
|
case "StripeApiSecret":
|
||||||
constant.WorkerUrl = value
|
common.StripeApiSecret = value
|
||||||
case "WorkerValidKey":
|
case "StripeWebhookSecret":
|
||||||
constant.WorkerValidKey = value
|
common.StripeWebhookSecret = value
|
||||||
case "PayAddress":
|
case "StripePriceId":
|
||||||
constant.PayAddress = value
|
common.StripePriceId = value
|
||||||
case "CustomCallbackAddress":
|
case "PaymentEnabled":
|
||||||
constant.CustomCallbackAddress = value
|
common.PaymentEnabled, _ = strconv.ParseBool(value)
|
||||||
case "EpayId":
|
case "StripeUnitPrice":
|
||||||
constant.EpayId = value
|
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||||
case "EpayKey":
|
|
||||||
constant.EpayKey = value
|
|
||||||
case "Price":
|
|
||||||
constant.Price, _ = strconv.ParseFloat(value, 64)
|
|
||||||
case "MinTopUp":
|
case "MinTopUp":
|
||||||
constant.MinTopUp, _ = strconv.Atoi(value)
|
common.MinTopUp, _ = strconv.Atoi(value)
|
||||||
case "TopupGroupRatio":
|
case "TopupGroupRatio":
|
||||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||||
case "GitHubClientId":
|
case "GitHubClientId":
|
||||||
common.GitHubClientId = value
|
common.GitHubClientId = value
|
||||||
case "GitHubClientSecret":
|
case "GitHubClientSecret":
|
||||||
common.GitHubClientSecret = value
|
common.GitHubClientSecret = value
|
||||||
|
case "LinuxDoClientId":
|
||||||
|
common.LinuxDoClientId = value
|
||||||
|
case "LinuxDoClientSecret":
|
||||||
|
common.LinuxDoClientSecret = value
|
||||||
|
case "LinuxDoMinLevel":
|
||||||
|
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
|
||||||
case "Footer":
|
case "Footer":
|
||||||
common.Footer = value
|
common.Footer = value
|
||||||
case "SystemName":
|
case "SystemName":
|
||||||
@@ -303,8 +305,6 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
err = common.UpdateModelRatioByJSONString(value)
|
err = common.UpdateModelRatioByJSONString(value)
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = common.UpdateGroupRatioByJSONString(value)
|
err = common.UpdateGroupRatioByJSONString(value)
|
||||||
case "CompletionRatio":
|
|
||||||
err = common.UpdateCompletionRatioByJSONString(value)
|
|
||||||
case "ModelPrice":
|
case "ModelPrice":
|
||||||
err = common.UpdateModelPriceByJSONString(value)
|
err = common.UpdateModelPriceByJSONString(value)
|
||||||
case "TopUpLink":
|
case "TopUpLink":
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Pricing struct {
|
|
||||||
Available bool `json:"available"`
|
|
||||||
ModelName string `json:"model_name"`
|
|
||||||
QuotaType int `json:"quota_type"`
|
|
||||||
ModelRatio float64 `json:"model_ratio"`
|
|
||||||
ModelPrice float64 `json:"model_price"`
|
|
||||||
OwnerBy string `json:"owner_by"`
|
|
||||||
CompletionRatio float64 `json:"completion_ratio"`
|
|
||||||
EnableGroup []string `json:"enable_group,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
pricingMap []Pricing
|
|
||||||
lastGetPricingTime time.Time
|
|
||||||
updatePricingLock sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetPricing(group string) []Pricing {
|
|
||||||
updatePricingLock.Lock()
|
|
||||||
defer updatePricingLock.Unlock()
|
|
||||||
|
|
||||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
|
||||||
updatePricing()
|
|
||||||
}
|
|
||||||
if group != "" {
|
|
||||||
userPricingMap := make([]Pricing, 0)
|
|
||||||
models := GetGroupModels(group)
|
|
||||||
for _, pricing := range pricingMap {
|
|
||||||
if !common.StringsContains(models, pricing.ModelName) {
|
|
||||||
pricing.Available = false
|
|
||||||
}
|
|
||||||
userPricingMap = append(userPricingMap, pricing)
|
|
||||||
}
|
|
||||||
return userPricingMap
|
|
||||||
}
|
|
||||||
return pricingMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func updatePricing() {
|
|
||||||
//modelRatios := common.GetModelRatios()
|
|
||||||
enabledModels := GetEnabledModels()
|
|
||||||
allModels := make(map[string]int)
|
|
||||||
for i, model := range enabledModels {
|
|
||||||
allModels[model] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
pricingMap = make([]Pricing, 0)
|
|
||||||
for model, _ := range allModels {
|
|
||||||
pricing := Pricing{
|
|
||||||
Available: true,
|
|
||||||
ModelName: model,
|
|
||||||
}
|
|
||||||
modelPrice, findPrice := common.GetModelPrice(model, false)
|
|
||||||
if findPrice {
|
|
||||||
pricing.ModelPrice = modelPrice
|
|
||||||
pricing.QuotaType = 1
|
|
||||||
} else {
|
|
||||||
pricing.ModelRatio = common.GetModelRatio(model)
|
|
||||||
pricing.CompletionRatio = common.GetCompletionRatio(model)
|
|
||||||
pricing.QuotaType = 0
|
|
||||||
}
|
|
||||||
pricingMap = append(pricingMap, pricing)
|
|
||||||
}
|
|
||||||
lastGetPricingTime = time.Now()
|
|
||||||
}
|
|
||||||
@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New("兑换失败," + err.Error())
|
return 0, errors.New("兑换失败," + err.Error())
|
||||||
}
|
}
|
||||||
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
|
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
|
||||||
return redemption.Quota, nil
|
return redemption.Quota, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
304
model/task.go
304
model/task.go
@@ -1,304 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"one-api/constant"
|
|
||||||
commonRelay "one-api/relay/common"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskStatusNotStart TaskStatus = "NOT_START"
|
|
||||||
TaskStatusSubmitted = "SUBMITTED"
|
|
||||||
TaskStatusQueued = "QUEUED"
|
|
||||||
TaskStatusInProgress = "IN_PROGRESS"
|
|
||||||
TaskStatusFailure = "FAILURE"
|
|
||||||
TaskStatusSuccess = "SUCCESS"
|
|
||||||
TaskStatusUnknown = "UNKNOWN"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Task struct {
|
|
||||||
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
|
|
||||||
CreatedAt int64 `json:"created_at" gorm:"index"`
|
|
||||||
UpdatedAt int64 `json:"updated_at"`
|
|
||||||
TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id
|
|
||||||
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
|
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
|
||||||
ChannelId int `json:"channel_id" gorm:"index"`
|
|
||||||
Quota int `json:"quota"`
|
|
||||||
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
|
||||||
Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
|
|
||||||
FailReason string `json:"fail_reason"`
|
|
||||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
|
||||||
StartTime int64 `json:"start_time" gorm:"index"`
|
|
||||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
|
||||||
Progress string `json:"progress" gorm:"type:varchar(20);index"`
|
|
||||||
Properties Properties `json:"properties" gorm:"type:json"`
|
|
||||||
|
|
||||||
Data json.RawMessage `json:"data" gorm:"type:json"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) SetData(data any) {
|
|
||||||
b, _ := json.Marshal(data)
|
|
||||||
t.Data = json.RawMessage(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) GetData(v any) error {
|
|
||||||
err := json.Unmarshal(t.Data, &v)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type Properties struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Properties) Scan(val interface{}) error {
|
|
||||||
bytesValue, _ := val.([]byte)
|
|
||||||
return json.Unmarshal(bytesValue, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Properties) Value() (driver.Value, error) {
|
|
||||||
return json.Marshal(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
|
||||||
type SyncTaskQueryParams struct {
|
|
||||||
Platform constant.TaskPlatform
|
|
||||||
ChannelID string
|
|
||||||
TaskID string
|
|
||||||
UserID string
|
|
||||||
Action string
|
|
||||||
Status string
|
|
||||||
StartTimestamp int64
|
|
||||||
EndTimestamp int64
|
|
||||||
UserIDs []int
|
|
||||||
}
|
|
||||||
|
|
||||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
|
|
||||||
t := &Task{
|
|
||||||
UserId: relayInfo.UserId,
|
|
||||||
SubmitTime: time.Now().Unix(),
|
|
||||||
Status: TaskStatusNotStart,
|
|
||||||
Progress: "0%",
|
|
||||||
ChannelId: relayInfo.ChannelId,
|
|
||||||
Platform: platform,
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
|
|
||||||
var tasks []*Task
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// 初始化查询构建器
|
|
||||||
query := DB.Where("user_id = ?", userId)
|
|
||||||
|
|
||||||
if queryParams.TaskID != "" {
|
|
||||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
|
||||||
}
|
|
||||||
if queryParams.Action != "" {
|
|
||||||
query = query.Where("action = ?", queryParams.Action)
|
|
||||||
}
|
|
||||||
if queryParams.Status != "" {
|
|
||||||
query = query.Where("status = ?", queryParams.Status)
|
|
||||||
}
|
|
||||||
if queryParams.Platform != "" {
|
|
||||||
query = query.Where("platform = ?", queryParams.Platform)
|
|
||||||
}
|
|
||||||
if queryParams.StartTimestamp != 0 {
|
|
||||||
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
|
|
||||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
|
||||||
}
|
|
||||||
if queryParams.EndTimestamp != 0 {
|
|
||||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取数据
|
|
||||||
err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
|
|
||||||
var tasks []*Task
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// 初始化查询构建器
|
|
||||||
query := DB
|
|
||||||
|
|
||||||
// 添加过滤条件
|
|
||||||
if queryParams.ChannelID != "" {
|
|
||||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
|
||||||
}
|
|
||||||
if queryParams.Platform != "" {
|
|
||||||
query = query.Where("platform = ?", queryParams.Platform)
|
|
||||||
}
|
|
||||||
if queryParams.UserID != "" {
|
|
||||||
query = query.Where("user_id = ?", queryParams.UserID)
|
|
||||||
}
|
|
||||||
if len(queryParams.UserIDs) != 0 {
|
|
||||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
|
||||||
}
|
|
||||||
if queryParams.TaskID != "" {
|
|
||||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
|
||||||
}
|
|
||||||
if queryParams.Action != "" {
|
|
||||||
query = query.Where("action = ?", queryParams.Action)
|
|
||||||
}
|
|
||||||
if queryParams.Status != "" {
|
|
||||||
query = query.Where("status = ?", queryParams.Status)
|
|
||||||
}
|
|
||||||
if queryParams.StartTimestamp != 0 {
|
|
||||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
|
||||||
}
|
|
||||||
if queryParams.EndTimestamp != 0 {
|
|
||||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取数据
|
|
||||||
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAllUnFinishSyncTasks(limit int) []*Task {
|
|
||||||
var tasks []*Task
|
|
||||||
var err error
|
|
||||||
// get all tasks progress is not 100%
|
|
||||||
err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return tasks
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
|
|
||||||
if taskId == "" {
|
|
||||||
return nil, false, nil
|
|
||||||
}
|
|
||||||
var task *Task
|
|
||||||
var err error
|
|
||||||
err = DB.Where("task_id = ?", taskId).First(&task).Error
|
|
||||||
exist, err := RecordExist(err)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
return task, exist, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
|
|
||||||
if taskId == "" {
|
|
||||||
return nil, false, nil
|
|
||||||
}
|
|
||||||
var task *Task
|
|
||||||
var err error
|
|
||||||
err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
|
|
||||||
First(&task).Error
|
|
||||||
exist, err := RecordExist(err)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, err
|
|
||||||
}
|
|
||||||
return task, exist, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
|
|
||||||
if len(taskIds) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
var task []*Task
|
|
||||||
var err error
|
|
||||||
err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
|
|
||||||
Find(&task).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return task, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskUpdateProgress(id int64, progress string) error {
|
|
||||||
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (Task *Task) Insert() error {
|
|
||||||
var err error
|
|
||||||
err = DB.Create(Task).Error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (Task *Task) Update() error {
|
|
||||||
var err error
|
|
||||||
err = DB.Save(Task).Error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
|
|
||||||
if len(TaskIds) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return DB.Model(&Task{}).
|
|
||||||
Where("task_id in (?)", TaskIds).
|
|
||||||
Updates(params).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
|
|
||||||
if len(taskIDs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return DB.Model(&Task{}).
|
|
||||||
Where("id in (?)", taskIDs).
|
|
||||||
Updates(params).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
|
||||||
if len(ids) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return DB.Model(&Task{}).
|
|
||||||
Where("id in (?)", ids).
|
|
||||||
Updates(params).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TaskQuotaUsage struct {
|
|
||||||
Mode string `json:"mode"`
|
|
||||||
Count float64 `json:"count"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
|
|
||||||
query := DB.Model(Task{})
|
|
||||||
// 添加过滤条件
|
|
||||||
if queryParams.ChannelID != "" {
|
|
||||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
|
||||||
}
|
|
||||||
if queryParams.UserID != "" {
|
|
||||||
query = query.Where("user_id = ?", queryParams.UserID)
|
|
||||||
}
|
|
||||||
if len(queryParams.UserIDs) != 0 {
|
|
||||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
|
||||||
}
|
|
||||||
if queryParams.TaskID != "" {
|
|
||||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
|
||||||
}
|
|
||||||
if queryParams.Action != "" {
|
|
||||||
query = query.Where("action = ?", queryParams.Action)
|
|
||||||
}
|
|
||||||
if queryParams.Status != "" {
|
|
||||||
query = query.Where("status = ?", queryParams.Status)
|
|
||||||
}
|
|
||||||
if queryParams.StartTimestamp != 0 {
|
|
||||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
|
||||||
}
|
|
||||||
if queryParams.EndTimestamp != 0 {
|
|
||||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
|
||||||
}
|
|
||||||
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
|
||||||
return stat, err
|
|
||||||
}
|
|
||||||
@@ -5,14 +5,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
UserId int `json:"user_id"`
|
||||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index" `
|
Name string `json:"name" gorm:"index" `
|
||||||
@@ -250,9 +249,11 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
|
|||||||
if userQuota < quota {
|
if userQuota < quota {
|
||||||
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||||
}
|
}
|
||||||
err = DecreaseTokenQuota(tokenId, quota)
|
if !token.UnlimitedQuota {
|
||||||
if err != nil {
|
err = DecreaseTokenQuota(tokenId, quota)
|
||||||
return 0, err
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = DecreaseUserQuota(token.UserId, quota)
|
err = DecreaseUserQuota(token.UserId, quota)
|
||||||
return userQuota - quota, err
|
return userQuota - quota, err
|
||||||
@@ -270,13 +271,15 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if quota > 0 {
|
if !token.UnlimitedQuota {
|
||||||
err = DecreaseTokenQuota(tokenId, quota)
|
if quota > 0 {
|
||||||
} else {
|
err = DecreaseTokenQuota(tokenId, quota)
|
||||||
err = IncreaseTokenQuota(tokenId, -quota)
|
} else {
|
||||||
}
|
err = IncreaseTokenQuota(tokenId, -quota)
|
||||||
if err != nil {
|
}
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sendEmail {
|
if sendEmail {
|
||||||
@@ -294,7 +297,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
|
|||||||
prompt = "您的额度已用尽"
|
prompt = "您的额度已用尽"
|
||||||
}
|
}
|
||||||
if email != "" {
|
if email != "" {
|
||||||
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
|
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
|
||||||
err = common.SendEmail(prompt, email,
|
err = common.SendEmail(prompt, email,
|
||||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
type TopUp struct {
|
type TopUp struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
UserId int `json:"user_id" gorm:"index"`
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
Amount int `json:"amount"`
|
Amount int `json:"amount"`
|
||||||
Money float64 `json:"money"`
|
Money float64 `json:"money"`
|
||||||
TradeNo string `json:"trade_no"`
|
TradeNo string `json:"trade_no" gorm:"unique"`
|
||||||
CreateTime int64 `json:"create_time"`
|
CreateTime int64 `json:"create_time"`
|
||||||
Status string `json:"status"`
|
CompleteTime int64 `json:"complete_time"`
|
||||||
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (topUp *TopUp) Insert() error {
|
func (topUp *TopUp) Insert() error {
|
||||||
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
|||||||
}
|
}
|
||||||
return topUp
|
return topUp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Recharge(referenceId string, customerId string) (err error) {
|
||||||
|
if referenceId == "" {
|
||||||
|
return errors.New("未提供支付单号")
|
||||||
|
}
|
||||||
|
|
||||||
|
var quota float64
|
||||||
|
topUp := &TopUp{}
|
||||||
|
|
||||||
|
refCol := "`trade_no`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
refCol = `"trade_no"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("充值订单不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
if topUp.Status != common.TopUpStatusPending {
|
||||||
|
return errors.New("充值订单状态错误")
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp.CompleteTime = common.GetTimestamp()
|
||||||
|
topUp.Status = common.TopUpStatusSuccess
|
||||||
|
err = tx.Save(topUp).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
quota = topUp.Money * common.QuotaPerUnit
|
||||||
|
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("充值失败," + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.LogQuotaF(quota), topUp.Amount))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int,
|
|||||||
if ok {
|
if ok {
|
||||||
quotaData.Count += 1
|
quotaData.Count += 1
|
||||||
quotaData.Quota += quota
|
quotaData.Quota += quota
|
||||||
quotaData.TokenUsed += tokenUsed
|
|
||||||
} else {
|
} else {
|
||||||
quotaData = &QuotaData{
|
quotaData = &QuotaData{
|
||||||
UserID: userId,
|
UserID: userId,
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ type User struct {
|
|||||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||||
|
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
|
||||||
|
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
|
||||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||||
@@ -35,6 +37,7 @@ type User struct {
|
|||||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||||
|
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
|
||||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
|||||||
|
|
||||||
func GetMaxUserId() int {
|
func GetMaxUserId() int {
|
||||||
var user User
|
var user User
|
||||||
DB.Last(&user)
|
DB.Unscoped().Last(&user)
|
||||||
return user.Id
|
return user.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,6 +122,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
|
|||||||
return &user, err
|
return &user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
|
||||||
|
if id == 0 {
|
||||||
|
return nil, errors.New("id 为空!")
|
||||||
|
}
|
||||||
|
user := User{Id: id}
|
||||||
|
var err error = nil
|
||||||
|
if selectAll {
|
||||||
|
err = DB.Unscoped().First(&user, "id = ?", id).Error
|
||||||
|
} else {
|
||||||
|
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
return &user, err
|
||||||
|
}
|
||||||
|
|
||||||
func GetUserIdByAffCode(affCode string) (int, error) {
|
func GetUserIdByAffCode(affCode string) (int, error) {
|
||||||
if affCode == "" {
|
if affCode == "" {
|
||||||
return 0, errors.New("affCode 为空!")
|
return 0, errors.New("affCode 为空!")
|
||||||
@@ -295,15 +312,10 @@ func (user *User) ValidateAndFill() (err error) {
|
|||||||
// that means if your field’s value is 0, '', false or other zero values,
|
// that means if your field’s value is 0, '', false or other zero values,
|
||||||
// it won’t be used to build query conditions
|
// it won’t be used to build query conditions
|
||||||
password := user.Password
|
password := user.Password
|
||||||
if (user.Username == "" && user.Email == "") || password == "" {
|
if user.Username == "" || password == "" {
|
||||||
return errors.New("用户名或密码为空")
|
return errors.New("用户名或密码为空")
|
||||||
}
|
}
|
||||||
// find buy username or email
|
DB.Where(User{Username: user.Username}).First(user)
|
||||||
if user.Username != "" {
|
|
||||||
DB.Where(User{Username: user.Username}).First(user)
|
|
||||||
} else if user.Email != "" {
|
|
||||||
DB.Where(User{Email: user.Email}).First(user)
|
|
||||||
}
|
|
||||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||||
if !okay || user.Status != common.UserStatusEnabled {
|
if !okay || user.Status != common.UserStatusEnabled {
|
||||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||||
@@ -335,6 +347,14 @@ func (user *User) FillUserByGitHubId() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) FillUserByLinuxDoId() error {
|
||||||
|
if user.LinuxDoId == "" {
|
||||||
|
return errors.New("LINUX DO id 为空!")
|
||||||
|
}
|
||||||
|
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) FillUserByWeChatId() error {
|
func (user *User) FillUserByWeChatId() error {
|
||||||
if user.WeChatId == "" {
|
if user.WeChatId == "" {
|
||||||
return errors.New("WeChat id 为空!")
|
return errors.New("WeChat id 为空!")
|
||||||
@@ -374,6 +394,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
|||||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
|
||||||
|
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
|
||||||
|
}
|
||||||
|
|
||||||
func IsUsernameAlreadyTaken(username string) bool {
|
func IsUsernameAlreadyTaken(username string) bool {
|
||||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||||
}
|
}
|
||||||
@@ -419,6 +443,18 @@ func IsUserEnabled(userId int) (bool, error) {
|
|||||||
return user.Status == common.UserStatusEnabled, nil
|
return user.Status == common.UserStatusEnabled, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsLinuxDoEnabled(userId int) (bool, error) {
|
||||||
|
if userId == 0 {
|
||||||
|
return false, errors.New("user id is empty")
|
||||||
|
}
|
||||||
|
var user User
|
||||||
|
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ValidateAccessToken(token string) (user *User) {
|
func ValidateAccessToken(token string) (user *User) {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -77,13 +75,3 @@ func batchUpdate() {
|
|||||||
}
|
}
|
||||||
common.SysLog("batch update finished")
|
common.SysLog("batch update finished")
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordExist(err error) (bool, error) {
|
|
||||||
if err == nil {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,32 +11,11 @@ import (
|
|||||||
type Adaptor interface {
|
type Adaptor interface {
|
||||||
// Init IsStream bool
|
// Init IsStream bool
|
||||||
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
|
||||||
InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest)
|
|
||||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||||
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
|
||||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskAdaptor interface {
|
|
||||||
Init(info *relaycommon.TaskRelayInfo)
|
|
||||||
|
|
||||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
|
|
||||||
|
|
||||||
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
|
|
||||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
|
|
||||||
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
|
|
||||||
|
|
||||||
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
|
|
||||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
|
||||||
|
|
||||||
GetModelList() []string
|
|
||||||
GetChannelName() string
|
|
||||||
|
|
||||||
// FetchTask
|
|
||||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,13 +1,8 @@
|
|||||||
package ai360
|
package ai360
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"360gpt-turbo",
|
|
||||||
"360gpt-turbo-responsibility-8k",
|
|
||||||
"360gpt-pro",
|
|
||||||
"360GPT_S2_V9",
|
"360GPT_S2_V9",
|
||||||
"embedding-bert-512-v1",
|
"embedding-bert-512-v1",
|
||||||
"embedding_s1_v1",
|
"embedding_s1_v1",
|
||||||
"semantic_similarity_s1_v1",
|
"semantic_similarity_s1_v1",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "ai360"
|
|
||||||
|
|||||||
@@ -15,9 +15,6 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -56,10 +53,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,27 +50,3 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
|||||||
_ = c.Request.Body.Close()
|
_ = c.Request.Body.Close()
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
fullRequestURL, err := a.BuildRequestURL(info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("new request failed: %w", err)
|
|
||||||
}
|
|
||||||
req.GetBody = func() (io.ReadCloser, error) {
|
|
||||||
return io.NopCloser(requestBody), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = a.BuildRequestHeader(c, req, info)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
|
||||||
}
|
|
||||||
resp, err := doRequest(c, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("do request failed: %w", err)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -20,11 +20,6 @@ type Adaptor struct {
|
|||||||
RequestMode int
|
RequestMode int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
//TODO implement me
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||||
a.RequestMode = RequestModeMessage
|
a.RequestMode = RequestModeMessage
|
||||||
@@ -58,17 +53,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
return claudeReq, err
|
return claudeReq, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = awsStreamHandler(c, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
err, usage = awsHandler(c, info, a.RequestMode)
|
err, usage = awsHandler(c, info, a.RequestMode)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ var awsModelIDMap = map[string]string{
|
|||||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "aws"
|
var ChannelName = "aws"
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ import (
|
|||||||
relaymodel "one-api/dto"
|
relaymodel "one-api/dto"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
@@ -113,7 +111,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
awsCli, err := newAwsClient(c, info)
|
awsCli, err := newAwsClient(c, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||||
@@ -158,20 +156,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
var usage relaymodel.Usage
|
var usage relaymodel.Usage
|
||||||
var id string
|
var id string
|
||||||
var model string
|
var model string
|
||||||
isFirst := true
|
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
event, ok := <-stream.Events()
|
event, ok := <-stream.Events()
|
||||||
if !ok {
|
if !ok {
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch v := event.(type) {
|
switch v := event.(type) {
|
||||||
case *types.ResponseStreamMemberChunk:
|
case *types.ResponseStreamMemberChunk:
|
||||||
if isFirst {
|
|
||||||
isFirst = false
|
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
claudeResp := new(claude.ClaudeResponse)
|
claudeResp := new(claude.ClaudeResponse)
|
||||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -214,17 +208,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if info.ShouldIncludeUsage {
|
|
||||||
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
|
|
||||||
err := service.ObjectData(c, response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("send final response failed: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
service.Done(c)
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package baidu
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -10,80 +9,33 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
//TODO implement me
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
|
var fullRequestURL string
|
||||||
suffix := "chat/"
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
|
|
||||||
suffix = "embeddings/"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
|
|
||||||
suffix = "embeddings/"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
|
|
||||||
suffix = "embeddings/"
|
|
||||||
}
|
|
||||||
switch info.UpstreamModelName {
|
switch info.UpstreamModelName {
|
||||||
case "ERNIE-4.0":
|
|
||||||
suffix += "completions_pro"
|
|
||||||
case "ERNIE-Bot-4":
|
case "ERNIE-Bot-4":
|
||||||
suffix += "completions_pro"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||||
case "ERNIE-Bot":
|
|
||||||
suffix += "completions"
|
|
||||||
case "ERNIE-Bot-turbo":
|
|
||||||
suffix += "eb-instant"
|
|
||||||
case "ERNIE-Speed":
|
|
||||||
suffix += "ernie_speed"
|
|
||||||
case "ERNIE-4.0-8K":
|
|
||||||
suffix += "completions_pro"
|
|
||||||
case "ERNIE-3.5-8K":
|
|
||||||
suffix += "completions"
|
|
||||||
case "ERNIE-3.5-8K-0205":
|
|
||||||
suffix += "ernie-3.5-8k-0205"
|
|
||||||
case "ERNIE-3.5-8K-1222":
|
|
||||||
suffix += "ernie-3.5-8k-1222"
|
|
||||||
case "ERNIE-Bot-8K":
|
case "ERNIE-Bot-8K":
|
||||||
suffix += "ernie_bot_8k"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
|
||||||
case "ERNIE-3.5-4K-0205":
|
case "ERNIE-Bot":
|
||||||
suffix += "ernie-3.5-4k-0205"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||||
case "ERNIE-Speed-8K":
|
case "ERNIE-Speed":
|
||||||
suffix += "ernie_speed"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
|
||||||
case "ERNIE-Speed-128K":
|
case "ERNIE-Bot-turbo":
|
||||||
suffix += "ernie-speed-128k"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||||
case "ERNIE-Lite-8K-0922":
|
|
||||||
suffix += "eb-instant"
|
|
||||||
case "ERNIE-Lite-8K-0308":
|
|
||||||
suffix += "ernie-lite-8k"
|
|
||||||
case "ERNIE-Tiny-8K":
|
|
||||||
suffix += "ernie-tiny-8k"
|
|
||||||
case "BLOOMZ-7B":
|
case "BLOOMZ-7B":
|
||||||
suffix += "bloomz_7b1"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||||
case "Embedding-V1":
|
case "Embedding-V1":
|
||||||
suffix += "embedding-v1"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
|
||||||
case "bge-large-zh":
|
|
||||||
suffix += "bge_large_zh"
|
|
||||||
case "bge-large-en":
|
|
||||||
suffix += "bge_large_en"
|
|
||||||
case "tao-8k":
|
|
||||||
suffix += "tao_8k"
|
|
||||||
default:
|
|
||||||
suffix += strings.ToLower(info.UpstreamModelName)
|
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
|
|
||||||
var accessToken string
|
var accessToken string
|
||||||
var err error
|
var err error
|
||||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||||
@@ -113,10 +65,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,12 @@
|
|||||||
package baidu
|
package baidu
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"ERNIE-4.0-8K",
|
"ERNIE-Bot-4",
|
||||||
"ERNIE-3.5-8K",
|
|
||||||
"ERNIE-3.5-8K-0205",
|
|
||||||
"ERNIE-3.5-8K-1222",
|
|
||||||
"ERNIE-Bot-8K",
|
"ERNIE-Bot-8K",
|
||||||
"ERNIE-3.5-4K-0205",
|
"ERNIE-Bot",
|
||||||
"ERNIE-Speed-8K",
|
"ERNIE-Speed",
|
||||||
"ERNIE-Speed-128K",
|
"ERNIE-Bot-turbo",
|
||||||
"ERNIE-Lite-8K-0922",
|
|
||||||
"ERNIE-Lite-8K-0308",
|
|
||||||
"ERNIE-Tiny-8K",
|
|
||||||
"BLOOMZ-7B",
|
|
||||||
"Embedding-V1",
|
"Embedding-V1",
|
||||||
"bge-large-zh",
|
|
||||||
"bge-large-en",
|
|
||||||
"tao-8k",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "baidu"
|
var ChannelName = "baidu"
|
||||||
|
|||||||
@@ -11,16 +11,9 @@ type BaiduMessage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type BaiduChatRequest struct {
|
type BaiduChatRequest struct {
|
||||||
Messages []BaiduMessage `json:"messages"`
|
Messages []BaiduMessage `json:"messages"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Stream bool `json:"stream"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
UserId string `json:"user_id,omitempty"`
|
||||||
PenaltyScore float64 `json:"penalty_score,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
DisableSearch bool `json:"disable_search,omitempty"`
|
|
||||||
EnableCitation bool `json:"enable_citation,omitempty"`
|
|
||||||
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
|
|
||||||
UserId string `json:"user_id,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
|
|||||||
@@ -22,27 +22,17 @@ import (
|
|||||||
var baiduTokenStore sync.Map
|
var baiduTokenStore sync.Map
|
||||||
|
|
||||||
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||||
baiduRequest := BaiduChatRequest{
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||||
Temperature: request.Temperature,
|
|
||||||
TopP: request.TopP,
|
|
||||||
PenaltyScore: request.FrequencyPenalty,
|
|
||||||
Stream: request.Stream,
|
|
||||||
DisableSearch: false,
|
|
||||||
EnableCitation: false,
|
|
||||||
MaxOutputTokens: int(request.MaxTokens),
|
|
||||||
UserId: request.User,
|
|
||||||
}
|
|
||||||
for _, message := range request.Messages {
|
for _, message := range request.Messages {
|
||||||
if message.Role == "system" {
|
messages = append(messages, BaiduMessage{
|
||||||
baiduRequest.System = message.StringContent()
|
Role: message.Role,
|
||||||
} else {
|
Content: message.StringContent(),
|
||||||
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
|
})
|
||||||
Role: message.Role,
|
}
|
||||||
Content: message.StringContent(),
|
return &BaiduChatRequest{
|
||||||
})
|
Messages: messages,
|
||||||
}
|
Stream: request.Stream,
|
||||||
}
|
}
|
||||||
return &baiduRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||||
|
|||||||
@@ -21,11 +21,6 @@ type Adaptor struct {
|
|||||||
RequestMode int
|
RequestMode int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
//TODO implement me
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||||
a.RequestMode = RequestModeMessage
|
a.RequestMode = RequestModeMessage
|
||||||
@@ -64,17 +59,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ var ModelList = []string{
|
|||||||
"claude-3-sonnet-20240229",
|
"claude-3-sonnet-20240229",
|
||||||
"claude-3-opus-20240229",
|
"claude-3-opus-20240229",
|
||||||
"claude-3-haiku-20240307",
|
"claude-3-haiku-20240307",
|
||||||
"claude-3-5-sonnet-20240620",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "claude"
|
var ChannelName = "claude"
|
||||||
|
|||||||
@@ -8,12 +8,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func stopReasonClaude2OpenAI(reason string) string {
|
func stopReasonClaude2OpenAI(reason string) string {
|
||||||
@@ -141,11 +138,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
|||||||
// 判断是否是url
|
// 判断是否是url
|
||||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||||
// 是url,获取图片的类型和base64编码的数据
|
// 是url,获取图片的类型和base64编码的数据
|
||||||
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
|
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||||||
claudeMediaMessage.Source.MediaType = mimeType
|
claudeMediaMessage.Source.MediaType = mimeType
|
||||||
claudeMediaMessage.Source.Data = data
|
claudeMediaMessage.Source.Data = data
|
||||||
} else {
|
} else {
|
||||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -249,7 +246,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
|||||||
return &fullTextResponse
|
return &fullTextResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
var usage *dto.Usage
|
var usage *dto.Usage
|
||||||
usage = &dto.Usage{}
|
usage = &dto.Usage{}
|
||||||
@@ -268,8 +265,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
})
|
})
|
||||||
dataChan := make(chan string, 5)
|
dataChan := make(chan string)
|
||||||
stopChan := make(chan bool, 2)
|
stopChan := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@@ -277,23 +274,14 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
data = strings.TrimPrefix(data, "data: ")
|
data = strings.TrimPrefix(data, "data: ")
|
||||||
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
dataChan <- data
|
||||||
// send data timeout, stop the stream
|
|
||||||
common.LogError(c, "send data timeout, stop the stream")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
isFirst := true
|
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
if isFirst {
|
|
||||||
isFirst = false
|
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
// some implementations may add \r at the end of data
|
// some implementations may add \r at the end of data
|
||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
var claudeResponse ClaudeResponse
|
var claudeResponse ClaudeResponse
|
||||||
@@ -314,7 +302,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
if claudeResponse.Type == "message_start" {
|
if claudeResponse.Type == "message_start" {
|
||||||
// message_start, 获取usage
|
// message_start, 获取usage
|
||||||
responseId = claudeResponse.Message.Id
|
responseId = claudeResponse.Message.Id
|
||||||
info.UpstreamModelName = claudeResponse.Message.Model
|
modelName = claudeResponse.Message.Model
|
||||||
usage.PromptTokens = claudeUsage.InputTokens
|
usage.PromptTokens = claudeUsage.InputTokens
|
||||||
} else if claudeResponse.Type == "content_block_delta" {
|
} else if claudeResponse.Type == "content_block_delta" {
|
||||||
responseText += claudeResponse.Delta.Text
|
responseText += claudeResponse.Delta.Text
|
||||||
@@ -328,38 +316,30 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
//response.Id = responseId
|
//response.Id = responseId
|
||||||
response.Id = responseId
|
response.Id = responseId
|
||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = modelName
|
||||||
|
|
||||||
err = service.ObjectData(c, response)
|
jsonStr, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(err.Error())
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
return true
|
return true
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
if requestMode == RequestModeCompletion {
|
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
|
||||||
if usage.PromptTokens == 0 {
|
|
||||||
usage.PromptTokens = info.PromptTokens
|
|
||||||
}
|
|
||||||
if usage.CompletionTokens == 0 {
|
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if info.ShouldIncludeUsage {
|
|
||||||
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
|
|
||||||
err := service.ObjectData(c, response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("send final response failed: " + err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
service.Done(c)
|
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if requestMode == RequestModeCompletion {
|
||||||
|
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||||
|
} else {
|
||||||
|
if usage.CompletionTokens == 0 {
|
||||||
|
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
@@ -390,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
|
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,24 +8,16 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
@@ -42,19 +34,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return requestConvertRerank2Cohere(request), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.IsStream {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if info.IsStream {
|
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||||
err, usage = cohereStreamHandler(c, resp, info)
|
|
||||||
} else {
|
|
||||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package cohere
|
|||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
||||||
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "cohere"
|
var ChannelName = "cohere"
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
package cohere
|
package cohere
|
||||||
|
|
||||||
import "one-api/dto"
|
|
||||||
|
|
||||||
type CohereRequest struct {
|
type CohereRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
ChatHistory []ChatHistory `json:"chat_history"`
|
ChatHistory []ChatHistory `json:"chat_history"`
|
||||||
@@ -30,19 +28,6 @@ type CohereResponseResult struct {
|
|||||||
Meta CohereMeta `json:"meta"`
|
Meta CohereMeta `json:"meta"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CohereRerankRequest struct {
|
|
||||||
Documents []any `json:"documents"`
|
|
||||||
Query string `json:"query"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
TopN int `json:"top_n"`
|
|
||||||
ReturnDocuments bool `json:"return_documents"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CohereRerankResponseResult struct {
|
|
||||||
Results []dto.RerankResponseDocument `json:"results"`
|
|
||||||
Meta CohereMeta `json:"meta"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CohereMeta struct {
|
type CohereMeta struct {
|
||||||
//Tokens CohereTokens `json:"tokens"`
|
//Tokens CohereTokens `json:"tokens"`
|
||||||
BilledUnits CohereBilledUnits `json:"billed_units"`
|
BilledUnits CohereBilledUnits `json:"billed_units"`
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||||
@@ -47,20 +45,6 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
|||||||
return &cohereReq
|
return &cohereReq
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
|
|
||||||
if rerankRequest.TopN == 0 {
|
|
||||||
rerankRequest.TopN = 1
|
|
||||||
}
|
|
||||||
cohereReq := CohereRerankRequest{
|
|
||||||
Query: rerankRequest.Query,
|
|
||||||
Documents: rerankRequest.Documents,
|
|
||||||
Model: rerankRequest.Model,
|
|
||||||
TopN: rerankRequest.TopN,
|
|
||||||
ReturnDocuments: true,
|
|
||||||
}
|
|
||||||
return &cohereReq
|
|
||||||
}
|
|
||||||
|
|
||||||
func stopReasonCohere2OpenAI(reason string) string {
|
func stopReasonCohere2OpenAI(reason string) string {
|
||||||
switch reason {
|
switch reason {
|
||||||
case "COMPLETE":
|
case "COMPLETE":
|
||||||
@@ -72,7 +56,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
@@ -100,14 +84,9 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
isFirst := true
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
if isFirst {
|
|
||||||
isFirst = false
|
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
data = strings.TrimSuffix(data, "\r")
|
data = strings.TrimSuffix(data, "\r")
|
||||||
var cohereResp CohereResponse
|
var cohereResp CohereResponse
|
||||||
err := json.Unmarshal([]byte(data), &cohereResp)
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||||
@@ -119,7 +98,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
openaiResp.Id = responseId
|
openaiResp.Id = responseId
|
||||||
openaiResp.Created = createdTime
|
openaiResp.Created = createdTime
|
||||||
openaiResp.Object = "chat.completion.chunk"
|
openaiResp.Object = "chat.completion.chunk"
|
||||||
openaiResp.Model = info.UpstreamModelName
|
openaiResp.Model = modelName
|
||||||
if cohereResp.IsFinished {
|
if cohereResp.IsFinished {
|
||||||
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
||||||
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||||
@@ -158,7 +137,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
@@ -208,42 +187,3 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
|||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var cohereResp CohereRerankResponseResult
|
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
usage := dto.Usage{}
|
|
||||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
|
||||||
usage.PromptTokens = info.PromptTokens
|
|
||||||
usage.CompletionTokens = 0
|
|
||||||
usage.TotalTokens = info.PromptTokens
|
|
||||||
} else {
|
|
||||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
||||||
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
||||||
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
var rerankResp dto.RerankResponse
|
|
||||||
rerankResp.Results = cohereResp.Results
|
|
||||||
rerankResp.Usage = usage
|
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(rerankResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &usage
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
package dify
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/relay/channel"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Adaptor struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
//TODO implement me
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
||||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
|
||||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
||||||
if request == nil {
|
|
||||||
return nil, errors.New("request is nil")
|
|
||||||
}
|
|
||||||
return requestOpenAI2Dify(*request), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
|
||||||
if info.IsStream {
|
|
||||||
err, usage = difyStreamHandler(c, resp, info)
|
|
||||||
} else {
|
|
||||||
err, usage = difyHandler(c, resp, info)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
|
||||||
return ModelList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
|
||||||
return ChannelName
|
|
||||||
}
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package dify
|
|
||||||
|
|
||||||
var ModelList []string
|
|
||||||
|
|
||||||
var ChannelName = "dify"
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package dify
|
|
||||||
|
|
||||||
import "one-api/dto"
|
|
||||||
|
|
||||||
type DifyChatRequest struct {
|
|
||||||
Inputs map[string]interface{} `json:"inputs"`
|
|
||||||
Query string `json:"query"`
|
|
||||||
ResponseMode string `json:"response_mode"`
|
|
||||||
User string `json:"user"`
|
|
||||||
AutoGenerateName bool `json:"auto_generate_name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DifyMetaData struct {
|
|
||||||
Usage dto.Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DifyData struct {
|
|
||||||
WorkflowId string `json:"workflow_id"`
|
|
||||||
NodeId string `json:"node_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DifyChatCompletionResponse struct {
|
|
||||||
ConversationId string `json:"conversation_id"`
|
|
||||||
Answers string `json:"answers"`
|
|
||||||
CreateAt int64 `json:"create_at"`
|
|
||||||
MetaData DifyMetaData `json:"metadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DifyChunkChatCompletionResponse struct {
|
|
||||||
Event string `json:"event"`
|
|
||||||
ConversationId string `json:"conversation_id"`
|
|
||||||
Answer string `json:"answer"`
|
|
||||||
Data DifyData `json:"data"`
|
|
||||||
MetaData DifyMetaData `json:"metadata"`
|
|
||||||
}
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
package dify
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/constant"
|
|
||||||
"one-api/dto"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/service"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
|
|
||||||
content := ""
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
if message.Role == "system" {
|
|
||||||
content += "SYSTEM: \n" + message.StringContent() + "\n"
|
|
||||||
} else if message.Role == "assistant" {
|
|
||||||
content += "ASSISTANT: \n" + message.StringContent() + "\n"
|
|
||||||
} else {
|
|
||||||
content += "USER: \n" + message.StringContent() + "\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mode := "blocking"
|
|
||||||
if request.Stream {
|
|
||||||
mode = "streaming"
|
|
||||||
}
|
|
||||||
user := request.User
|
|
||||||
if user == "" {
|
|
||||||
user = "api-user"
|
|
||||||
}
|
|
||||||
return &DifyChatRequest{
|
|
||||||
Inputs: make(map[string]interface{}),
|
|
||||||
Query: content,
|
|
||||||
ResponseMode: mode,
|
|
||||||
User: user,
|
|
||||||
AutoGenerateName: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
|
|
||||||
response := dto.ChatCompletionsStreamResponse{
|
|
||||||
Object: "chat.completion.chunk",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Model: "dify",
|
|
||||||
}
|
|
||||||
var choice dto.ChatCompletionsStreamResponseChoice
|
|
||||||
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
|
|
||||||
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
|
||||||
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
|
||||||
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
|
||||||
} else if difyResponse.Event == "message" {
|
|
||||||
choice.Delta.SetContentString(difyResponse.Answer)
|
|
||||||
}
|
|
||||||
response.Choices = append(response.Choices, choice)
|
|
||||||
return &response
|
|
||||||
}
|
|
||||||
|
|
||||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
var responseText string
|
|
||||||
usage := &dto.Usage{}
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(bufio.ScanLines)
|
|
||||||
|
|
||||||
service.SetEventStreamHeaders(c)
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = strings.TrimPrefix(data, "data:")
|
|
||||||
var difyResponse DifyChunkChatCompletionResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
|
||||||
if difyResponse.Event == "message_end" {
|
|
||||||
usage = &difyResponse.MetaData.Usage
|
|
||||||
break
|
|
||||||
} else if difyResponse.Event == "error" {
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
|
|
||||||
if len(openaiResponse.Choices) != 0 {
|
|
||||||
responseText += openaiResponse.Choices[0].Delta.GetContentString()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = service.ObjectData(c, openaiResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError(err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
common.SysError("error reading stream: " + err.Error())
|
|
||||||
}
|
|
||||||
service.Done(c)
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
common.SysError("close_response_body_failed: " + err.Error())
|
|
||||||
}
|
|
||||||
if usage.TotalTokens == 0 {
|
|
||||||
usage.PromptTokens = info.PromptTokens
|
|
||||||
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
|
||||||
return nil, usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
var difyResponse DifyChatCompletionResponse
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &difyResponse)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
|
||||||
Id: difyResponse.ConversationId,
|
|
||||||
Object: "chat.completion",
|
|
||||||
Created: common.GetTimestamp(),
|
|
||||||
Usage: difyResponse.MetaData.Usage,
|
|
||||||
}
|
|
||||||
content, _ := json.Marshal(difyResponse.Answers)
|
|
||||||
choice := dto.OpenAITextResponseChoice{
|
|
||||||
Index: 0,
|
|
||||||
Message: dto.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: content,
|
|
||||||
},
|
|
||||||
FinishReason: "stop",
|
|
||||||
}
|
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &difyResponse.MetaData.Usage
|
|
||||||
}
|
|
||||||
@@ -15,35 +15,31 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义一个映射,存储模型名称和对应的版本
|
// 定义一个映射,存储模型名称和对应的版本
|
||||||
var modelVersionMap = map[string]string{
|
var modelVersionMap = map[string]string{
|
||||||
"gemini-1.5-pro-latest": "v1beta",
|
"gemini-1.5-pro-latest": "v1beta",
|
||||||
"gemini-1.5-flash-latest": "v1beta",
|
"gemini-ultra": "v1beta",
|
||||||
"gemini-ultra": "v1beta",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
version, beta := modelVersionMap[info.UpstreamModelName]
|
||||||
if !beta {
|
if !beta {
|
||||||
if info.ApiVersion != "" {
|
if info.ApiVersion != "" {
|
||||||
version = info.ApiVersion
|
version = info.ApiVersion
|
||||||
} else {
|
} else {
|
||||||
version = "v1"
|
version = "v1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
@@ -59,10 +55,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
return CovertGemini2OpenAI(*request), nil
|
return CovertGemini2OpenAI(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
@@ -70,7 +62,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = geminiChatStreamHandler(c, resp, info)
|
err, responseText = geminiChatStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
|
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-ultra",
|
||||||
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
|
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -76,7 +74,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
|||||||
if imageNum > GeminiVisionMaxImageNum {
|
if imageNum > GeminiVisionMaxImageNum {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
InlineData: &GeminiInlineData{
|
InlineData: &GeminiInlineData{
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
@@ -162,10 +160,10 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
dataChan := make(chan string, 5)
|
dataChan := make(chan string)
|
||||||
stopChan := make(chan bool, 2)
|
stopChan := make(chan bool)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
@@ -188,23 +186,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
}
|
}
|
||||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||||
data = strings.TrimSuffix(data, "\"")
|
data = strings.TrimSuffix(data, "\"")
|
||||||
if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
dataChan <- data
|
||||||
// send data timeout, stop the stream
|
|
||||||
common.LogError(c, "send data timeout, stop the stream")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
isFirst := true
|
|
||||||
service.SetEventStreamHeaders(c)
|
service.SetEventStreamHeaders(c)
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
if isFirst {
|
|
||||||
isFirst = false
|
|
||||||
info.FirstResponseTime = time.Now()
|
|
||||||
}
|
|
||||||
// this is used to prevent annoying \ related format bug
|
// this is used to prevent annoying \ related format bug
|
||||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||||
type dummyStruct struct {
|
type dummyStruct struct {
|
||||||
@@ -267,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||||
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
package jina
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/relay/channel"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/relay/constant"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Adaptor struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
|
||||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
|
||||||
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
return "", errors.New("invalid relay mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
|
||||||
channel.SetupApiRequestHeader(info, c, req)
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
|
||||||
err, usage = jinaRerankHandler(c, resp)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetModelList() []string {
|
|
||||||
return ModelList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) GetChannelName() string {
|
|
||||||
return ChannelName
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
package jina
|
|
||||||
|
|
||||||
var ModelList = []string{
|
|
||||||
"jina-clip-v1",
|
|
||||||
"jina-reranker-v2-base-multilingual",
|
|
||||||
}
|
|
||||||
|
|
||||||
var ChannelName = "jina"
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package jina
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
|
||||||
"one-api/service"
|
|
||||||
)
|
|
||||||
|
|
||||||
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
var jinaResp dto.RerankResponse
|
|
||||||
err = json.Unmarshal(responseBody, &jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(jinaResp)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &jinaResp.Usage
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
package lingyiwanwu
|
package lingyiwanwu
|
||||||
|
|
||||||
// https://platform.lingyiwanwu.com/docs
|
// https://platform.lingyiwanwu.com/docs
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
"yi-large", "yi-medium", "yi-vision", "yi-medium-200k", "yi-spark", "yi-large-rag", "yi-large-turbo", "yi-large-preview", "yi-large-rag-preview",
|
"yi-34b-chat-0205",
|
||||||
}
|
"yi-34b-chat-200k",
|
||||||
|
"yi-vl-plus",
|
||||||
var ChannelName = "lingyiwanwu"
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package minimax
|
|
||||||
|
|
||||||
// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
|
|
||||||
|
|
||||||
var ModelList = []string{
|
|
||||||
"abab6.5-chat",
|
|
||||||
"abab6.5s-chat",
|
|
||||||
"abab6-chat",
|
|
||||||
"abab5.5-chat",
|
|
||||||
"abab5.5s-chat",
|
|
||||||
}
|
|
||||||
|
|
||||||
var ChannelName = "minimax"
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package minimax
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
||||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
|
|
||||||
}
|
|
||||||
@@ -5,5 +5,3 @@ var ModelList = []string{
|
|||||||
"moonshot-v1-32k",
|
"moonshot-v1-32k",
|
||||||
"moonshot-v1-128k",
|
"moonshot-v1-128k",
|
||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "moonshot"
|
|
||||||
|
|||||||
@@ -16,9 +16,6 @@ import (
|
|||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,10 +45,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
@@ -59,10 +52,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info)
|
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||||
if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList []string
|
||||||
"llama3-7b",
|
|
||||||
}
|
|
||||||
|
|
||||||
var ChannelName = "ollama"
|
var ChannelName = "ollama"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user