Compare commits

..

59 Commits

Author SHA1 Message Date
wozulong
1c371300ab merge upstream
Signed-off-by: wozulong <>
2024-08-06 15:54:28 +08:00
wozulong
918690701d merge upstream
Signed-off-by: wozulong <>
2024-08-01 15:12:54 +08:00
wozulong
f0008e95fa merge upstream
Signed-off-by: wozulong <>
2024-07-22 10:51:49 +08:00
wozulong
7a249b206d merge upstream
Signed-off-by: wozulong <>
2024-07-19 10:58:21 +08:00
wozulong
0cc7f5cca6 merge upstream
Signed-off-by: wozulong <>
2024-07-11 14:10:10 +08:00
wozulong
ed86ec8b59 Merge remote-tracking branch 'upstream/main' 2024-07-05 11:22:44 +08:00
wozulong
895ee09b33 merge upstream
Signed-off-by: wozulong <>
2024-07-01 15:14:22 +08:00
wozulong
61006bed9e merge upstream
Signed-off-by: wozulong <>
2024-06-25 18:11:23 +08:00
wozulong
58591b8c2a fix linuxdo oauth
Signed-off-by: wozulong <>
2024-05-31 13:12:35 +08:00
wozulong
69b9950aa0 merge upstream
Signed-off-by: wozulong <>
2024-05-31 11:14:25 +08:00
wozulong
bcf1b160d1 update
Signed-off-by: wozulong <>
2024-05-28 12:09:19 +08:00
wozulong
1e0e40aa69 update
Signed-off-by: wozulong <>
2024-05-28 12:04:10 +08:00
wozulong
74392efbed merge upstream
Signed-off-by: wozulong <>
2024-05-28 11:57:05 +08:00
wozulong
cc020d6a40 fix req model
Signed-off-by: wozulong <>
2024-05-24 16:46:11 +08:00
wozulong
daa1741aed merge upstream
Signed-off-by: wozulong <>
2024-05-22 11:38:59 +08:00
wozulong
a25bcaa58f fix the type of the logprobs
Signed-off-by: wozulong <>
2024-05-22 11:30:38 +08:00
wozulong
d34b601dae Merge remote-tracking branch 'upstream/main' 2024-05-19 16:03:53 +08:00
wozulong
9932962320 merge upstream
Signed-off-by: wozulong <>
2024-05-16 19:00:58 +08:00
wozulong
00d6cda9ed refine completions req type
Signed-off-by: wozulong <>
2024-05-16 18:54:21 +08:00
wozulong
5ffb520363 merge upstream
Signed-off-by: wozulong <>
2024-05-15 12:36:29 +08:00
wozulong
e0f80cdb8f add gpt-4o
Signed-off-by: wozulong <>
2024-05-15 11:49:35 +08:00
wozulong
7a7a923504 add gpt-4o tokenizer
Signed-off-by: wozulong <>
2024-05-14 14:47:57 +08:00
wozulong
4f6c171a08 add gpt-4o model ratio
Signed-off-by: wozulong <>
2024-05-14 02:38:44 +08:00
wozulong
310f8c247e fix auth
Signed-off-by: wozulong <>
2024-05-11 13:47:51 +08:00
wozulong
811019bf5c lint fix
Signed-off-by: wozulong <>
2024-05-10 14:15:38 +08:00
wozulong
0ed6600437 Merge remote-tracking branch 'upstream/main' 2024-05-10 14:08:43 +08:00
wozulong
1fe7f14d57 merge upstream
Signed-off-by: wozulong <>
2024-04-28 14:04:39 +08:00
wozulong
a7bafec1bf merge upstream
Signed-off-by: wozulong <>
2024-04-28 14:04:19 +08:00
wozulong
ed951b3974 merge upstream
Signed-off-by: wozulong <>
2024-04-25 16:01:18 +08:00
wozulong
c74e43b8fd merge upstream
Signed-off-by: wozulong <>
2024-04-18 15:18:28 +08:00
wozulong
03bd9b0cc4 merge upstream
Signed-off-by: wozulong <>
2024-04-10 11:31:20 +08:00
wozulong
149902bd8a add gpt-4-turbo
Signed-off-by: wozulong <>
2024-04-10 11:21:12 +08:00
wozulong
cd3ed22045 Merge remote-tracking branch 'upstream/main' 2024-04-06 20:24:20 +08:00
wozulong
2329d387ca fix icon
Signed-off-by: wozulong <>
2024-04-06 20:24:07 +08:00
wozulong
7d18a8e2a9 merge upstream
Signed-off-by: wozulong <>
2024-04-05 22:10:07 +08:00
wozulong
80af3718d0 lint fix
Signed-off-by: wozulong <>
2024-04-01 10:16:35 +08:00
wozulong
77ea6bec46 Merge remote-tracking branch 'upstream/main' 2024-04-01 10:15:48 +08:00
wozulong
c0ab8ae953 update github action
Signed-off-by: wozulong <>
2024-03-27 18:04:27 +08:00
wozulong
923c2dee32 merge upstream
Signed-off-by: wozulong <>
2024-03-27 18:01:43 +08:00
wozulong
ea17a46d8e fixed pagination issue in the log list
Signed-off-by: wozulong <>
2024-03-25 16:36:57 +08:00
wozulong
bfe9e5d25a Merge remote-tracking branch 'upstream/main' 2024-03-25 15:49:12 +08:00
wozulong
831ff47254 Merge remote-tracking branch 'upstream/main' 2024-03-24 00:18:02 +08:00
wozulong
11eaba6b5d fix gpt-3.5-turbo points
Signed-off-by: wozulong <>
2024-03-24 00:17:49 +08:00
wozulong
c2e4ec25c8 merge upstream
Signed-off-by: wozulong <>
2024-03-23 22:51:18 +08:00
wozulong
8537f10412 save stripe custome id
Signed-off-by: wozulong <>
2024-03-23 21:59:15 +08:00
wozulong
71d60eeef7 merge upstream
Signed-off-by: wozulong <>
2024-03-22 18:09:13 +08:00
wozulong
247ae0988f replace epay with stripe
Signed-off-by: wozulong <>
2024-03-22 18:00:20 +08:00
wozulong
0907fa6994 fix next uid
Signed-off-by: wozulong <>
2024-03-21 15:02:47 +08:00
wozulong
9855343aa8 Merge remote-tracking branch 'upstream/main' 2024-03-21 15:01:51 +08:00
wozulong
bd50fde268 Merge remote-tracking branch 'upstream/main' 2024-03-21 13:10:21 +08:00
wozulong
4267de5642 fix yarn timeout
Signed-off-by: wozulong <>
2024-03-20 19:14:17 +08:00
wozulong
8b55116563 update build files
Signed-off-by: wozulong <>
2024-03-20 18:31:22 +08:00
wozulong
f35e63e3f3 limit 'LINUX DO' trust level now available
Signed-off-by: wozulong <>
2024-03-20 16:54:38 +08:00
wozulong
17c409de23 update gh action
Signed-off-by: wozulong <>
2024-03-20 14:11:21 +08:00
wozulong
e4753e7411 synced with upstream
Signed-off-by: wozulong <>
2024-03-20 13:52:10 +08:00
paderlol
9adefa80b9 Optimizing Docker image builds (#1)
Co-authored-by: pader.zhang <pader.zhang@starlight-sms.com>
2024-03-14 23:10:05 -07:00
wozulong
4ce2381182 update issue template
Signed-off-by: wozulong <>
2024-03-14 20:02:22 +08:00
我秦始皇
62afc21ea5 Merge branch 'Calcium-Ion:main' into main 2024-03-14 03:56:31 -07:00
wozulong
7ddb7c586d 1. add LINUX DO oauth
2. fix oauth reg aff issue

Signed-off-by: wozulong <>
2024-03-14 18:53:54 +08:00
148 changed files with 8658 additions and 7979 deletions

12
.github/FUNDING.yml vendored
View File

@@ -1,12 +0,0 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: 项目群聊 - name: 交流社区
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg url: https://linux.do
about: QQ 群629454374 about: 项目交流社区

View File

@@ -1,9 +1,6 @@
name: Publish Docker image (amd64) name: Publish Docker image (amd64)
on: on:
push:
tags:
- '*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
@@ -42,7 +39,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
calciumion/new-api pengzhile/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
calciumion/new-api pengzhile/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend: build-frontend:
@echo "Building frontend..." @echo "Building frontend..."
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build @cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
start-backend: start-backend:
@echo "Starting backend dev server..." @echo "Starting backend dev server..."

View File

@@ -1,13 +1,6 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API # New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE] > [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
@@ -51,7 +44,7 @@
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-* 1. 第三方模型 **gps** gpt-4-gizmo-*, g-*
2. 智谱glm-4vglm-4v识图 2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 3. Anthropic Claude 3
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
@@ -61,12 +54,10 @@
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md) 8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md) 9. Rerank模型目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[对接文档](Rerank.md)
10. Dify 10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置 ## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。 - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
@@ -74,7 +65,7 @@
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true` - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。 - `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 - `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
## 部署 ## 部署
### 部署要求 ### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机) - 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
@@ -123,18 +114,23 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
## Suno接口设置文档 ## Suno接口设置文档
[对接文档](Suno.md) [对接文档](Suno.md)
## 界面截图 ## 交流群
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a) <img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式 夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 交流群 ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="200"> ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 相关项目 ## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目 - [One API](https://github.com/songquanpeng/one-api):原版项目

View File

@@ -9,9 +9,20 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Pay Settings
var StripeApiSecret = ""
var StripeWebhookSecret = ""
var StripePriceId = ""
var PaymentEnabled = false
var StripeUnitPrice = 8.0
var MinTopUp = 5
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API" var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var OutProxyUrl = ""
var Footer = "" var Footer = ""
var Logo = "" var Logo = ""
var TopUpLink = "" var TopUpLink = ""
@@ -41,10 +52,12 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制 var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制 var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
@@ -75,6 +88,10 @@ var SMTPToken = ""
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@@ -112,9 +129,6 @@ var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
) )
@@ -126,10 +140,6 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var ( var (
FileUploadPermission = RoleGuestUser FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser FileDownloadPermission = RoleGuestUser
@@ -183,6 +193,12 @@ const (
ChannelStatusAutoDisabled = 3 ChannelStatusAutoDisabled = 3
) )
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1 ChannelTypeOpenAI = 1
@@ -220,8 +236,6 @@ const (
ChannelTypeDify = 37 ChannelTypeDify = 37
ChannelTypeJina = 38 ChannelTypeJina = 38
ChannelCloudflare = 39 ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
@@ -268,6 +282,4 @@ var ChannelBaseURLs = []string{
"", //37 "", //37
"https://api.jina.ai", //38 "https://api.jina.ai", //38
"https://api.cloudflare.com", //39 "https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
} }

View File

@@ -3,7 +3,6 @@ package common
import ( import (
"errors" "errors"
"net/smtp" "net/smtp"
"strings"
) )
type outlookAuth struct { type outlookAuth struct {
@@ -31,10 +30,3 @@ func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
} }
return nil, nil return nil, nil
} }
func isOutlookServer(server string) bool {
// 兼容多地区的outlook邮箱和ofb邮箱
// 其实应该加一个Option来区分是否用LOGIN的方式登录
// 先临时兼容一下
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
}

View File

@@ -9,26 +9,17 @@ import (
"time" "time"
) )
func generateMessageID() string {
domain := strings.Split(SMTPAccount, "@")[1]
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
}
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility if SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount SMTPFrom = SMTPAccount
} }
if SMTPServer == "" && SMTPAccount == "" {
return fmt.Errorf("SMTP 服务器未配置")
}
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Date: %s\r\n"+ "Date: %s\r\n"+
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), generateMessageID(), content)) receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
@@ -71,7 +62,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err != nil { if err != nil {
return err return err
} }
} else if isOutlookServer(SMTPAccount) { } else if strings.HasSuffix(SMTPAccount, "outlook.com") {
auth = LoginAuth(SMTPAccount, SMTPToken) auth = LoginAuth(SMTPAccount, SMTPToken)
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} else { } else {

84
common/hash.go Normal file
View File

@@ -0,0 +1,84 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@@ -1,4 +1,4 @@
package service package common
import ( import (
"bytes" "bytes"
@@ -8,7 +8,6 @@ import (
"golang.org/x/image/webp" "golang.org/x/image/webp"
"image" "image"
"io" "io"
"one-api/common"
"strings" "strings"
) )
@@ -31,9 +30,24 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err return config, format, base64String, err
} }
func IsImageUrl(url string) (bool, error) {
resp, err := ProxiedHttpHead(url, OutProxyUrl)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据 // GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoImageRequest(url) isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := ProxiedHttpGet(url, OutProxyUrl)
if err != nil { if err != nil {
return return
} }
@@ -52,9 +66,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
} }
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoImageRequest(imageUrl) response, err := ProxiedHttpGet(imageUrl, OutProxyUrl)
if err != nil { if err != nil {
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err return image.Config{}, "", err
} }
defer response.Body.Close() defer response.Body.Close()
@@ -66,7 +80,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
var readData []byte var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制 // 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData))) additionalData := make([]byte, limit-int64(len(readData)))
@@ -92,11 +106,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader) config, format, err := image.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
common.SysLog(err.Error()) SysLog(err.Error())
config, err = webp.DecodeConfig(reader) config, err = webp.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
common.SysLog(err.Error()) SysLog(err.Error())
} }
format = "webp" format = "webp"
} }

View File

@@ -2,7 +2,6 @@ package common
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
} }
} }
// LogJson 仅供测试使用 only for test func LogQuotaF(quota float64) string {
func LogJson(ctx context.Context, msg string, obj any) { if DisplayInCurrencyEnabled {
jsonStr, err := json.Marshal(obj) return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit)
if err != nil { } else {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return fmt.Sprintf("%d 点额度", int64(quota))
return
} }
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
} }

View File

@@ -23,48 +23,40 @@ const (
var defaultModelRatio = map[string]float64{ var defaultModelRatio = map[string]float64{
//"midjourney": 50, //"midjourney": 50,
"gpt-4-gizmo-*": 15, "gpt-4-gizmo-*": 15,
"gpt-4o-gizmo-*": 2.5, "g-*": 15,
"gpt-4-all": 15, "gpt-4": 15,
"gpt-4o-all": 15, "gpt-4-0314": 15,
"gpt-4": 15, "gpt-4-0613": 15,
//"gpt-4-0314": 15, //deprecated "gpt-4-32k": 30,
"gpt-4-0613": 15, "gpt-4-32k-0314": 30,
"gpt-4-32k": 30, "gpt-4-32k-0613": 30,
//"gpt-4-32k-0314": 30, //deprecated "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5,
"o1-mini-2024-09-12": 1.5,
"gpt-4o-mini": 0.075,
"gpt-4o-mini-2024-07-18": 0.075, "gpt-4o-mini-2024-07-18": 0.075,
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4o": 2.5, // $0.005 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens "gpt-4-turbo": 5, // $0.01 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-0613": 0.75, "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens "gpt-3.5-turbo-0301": 0.75,
"davinci-002": 1, // $0.002 / 1K tokens "gpt-3.5-turbo-0613": 0.75,
"text-ada-001": 0.2, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"text-babbage-001": 0.25, "gpt-3.5-turbo-16k-0613": 1.5,
"text-curie-001": 1, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
//"text-davinci-002": 10, "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
//"text-davinci-003": 10, "gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
@@ -82,13 +74,13 @@ var defaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens "claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -110,10 +102,8 @@ var defaultModelRatio = map[string]float64{
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1, "gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1, "gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens "gemini-1.5-pro-latest": 1,
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1, "gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1, "gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1, "gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1, "gemini-ultra": 1,
@@ -125,13 +115,6 @@ var defaultModelRatio = map[string]float64{
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572, "glm-3-turbo": 0.3572,
"glm-4-plus": 0.05 * RMB,
"glm-4-0520": 0.1 * RMB,
"glm-4-air": 0.001 * RMB,
"glm-4-airx": 0.01 * RMB,
"glm-4-long": 0.001 * RMB,
"glm-4-flash": 0,
"glm-4v-plus": 0.01 * RMB,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
@@ -150,28 +133,26 @@ var defaultModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元 // https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格 // 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.18, "yi-34b-chat-0205": 0.18,
"yi-34b-chat-200k": 0.864, "yi-34b-chat-200k": 0.864,
"yi-vl-plus": 0.432, "yi-vl-plus": 0.432,
"yi-large": 20.0 / 1000 * RMB, "yi-large": 20.0 / 1000 * RMB,
"yi-medium": 2.5 / 1000 * RMB, "yi-medium": 2.5 / 1000 * RMB,
"yi-vision": 6.0 / 1000 * RMB, "yi-vision": 6.0 / 1000 * RMB,
"yi-medium-200k": 12.0 / 1000 * RMB, "yi-medium-200k": 12.0 / 1000 * RMB,
"yi-spark": 1.0 / 1000 * RMB, "yi-spark": 1.0 / 1000 * RMB,
"yi-large-rag": 25.0 / 1000 * RMB, "yi-large-rag": 25.0 / 1000 * RMB,
"yi-large-turbo": 12.0 / 1000 * RMB, "yi-large-turbo": 12.0 / 1000 * RMB,
"yi-large-preview": 20.0 / 1000 * RMB, "yi-large-preview": 20.0 / 1000 * RMB,
"yi-large-rag-preview": 25.0 / 1000 * RMB, "yi-large-rag-preview": 25.0 / 1000 * RMB,
"command": 0.5, "command": 0.5,
"command-nightly": 0.5, "command-nightly": 0.5,
"command-light": 0.5, "command-light": 0.5,
"command-light-nightly": 0.5, "command-light-nightly": 0.5,
"command-r": 0.25, "command-r": 0.25,
"command-r-plus": 1.5, "command-r-plus ": 1.5,
"command-r-08-2024": 0.075, "deepseek-chat": 0.07,
"command-r-plus-08-2024": 1.25, "deepseek-coder": 0.07,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用 // Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD, "llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
@@ -182,8 +163,10 @@ var defaultModelRatio = map[string]float64{
var defaultModelPrice = map[string]float64{ var defaultModelPrice = map[string]float64{
"suno_music": 0.1, "suno_music": 0.1,
"suno_lyrics": 0.01, "suno_lyrics": 0.01,
"dall-e-2": 0.02,
"dall-e-3": 0.04, "dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1, "gpt-4-gizmo-*": 0.1,
"g-*": 0.1,
"mj_imagine": 0.1, "mj_imagine": 0.1,
"mj_variation": 0.1, "mj_variation": 0.1,
"mj_reroll": 0.1, "mj_reroll": 0.1,
@@ -203,8 +186,8 @@ var defaultModelPrice = map[string]float64{
} }
var ( var (
modelPriceMap map[string]float64 = nil modelPriceMap = make(map[string]float64)
modelPriceMapMutex = sync.RWMutex{} modelPriceMapMutex = sync.RWMutex{}
) )
var ( var (
modelRatioMap map[string]float64 = nil modelRatioMap map[string]float64 = nil
@@ -213,9 +196,10 @@ var (
var CompletionRatio map[string]float64 = nil var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3, "g-*": 2,
"gpt-4-all": 2, "gpt-4-all": 2,
"gpt-4o-all": 2,
} }
func GetModelPriceMap() map[string]float64 { func GetModelPriceMap() map[string]float64 {
@@ -248,9 +232,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap() GetModelPriceMap()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} } else if strings.HasPrefix(name, "g-") {
if strings.HasPrefix(name, "gpt-4o-gizmo") { name = "g-*"
name = "gpt-4o-gizmo-*"
} }
price, ok := modelPriceMap[name] price, ok := modelPriceMap[name]
if !ok { if !ok {
@@ -291,6 +274,8 @@ func GetModelRatio(name string) float64 {
GetModelRatioMap() GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
ratio, ok := modelRatioMap[name] ratio, ok := modelRatioMap[name]
if !ok { if !ok {
@@ -331,51 +316,44 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} } else if strings.HasPrefix(name, "g-") {
if strings.HasPrefix(name, "gpt-4o-gizmo") { name = "g-*"
name = "gpt-4o-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
}
return 4
}
return 2
}
if strings.HasPrefix(name, "o1-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 4
}
if strings.Contains(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
return 5
} }
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { if strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3 return 3
} }
if strings.HasSuffix(name, "1106") { if strings.HasSuffix(name, "1106") {
return 2 return 2
} }
if name == "gpt-3.5-turbo" {
return 3
}
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasPrefix(name, "gpt-4o-mini") {
return 4
}
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3
} else if strings.HasPrefix(name, "claude-2") {
return 3
} else if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "mistral-") { if strings.HasPrefix(name, "mistral-") {
return 3 return 3
} }
if strings.HasPrefix(name, "gemini-") { if strings.HasPrefix(name, "gemini-") {
return 4 return 3
} }
if strings.HasPrefix(name, "command") { if strings.HasPrefix(name, "command") {
switch name { switch name {
@@ -383,10 +361,6 @@ func GetCompletionRatio(name string) float64 {
return 3 return 3
case "command-r-plus": case "command-r-plus":
return 5 return 5
case "command-r-08-2024":
return 4
case "command-r-plus-08-2024":
return 4
default: default:
return 2 return 2
} }

View File

@@ -31,6 +31,14 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes) return string(bytes)
} }
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} { func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := json.Unmarshal([]byte(str), &m)
@@ -40,11 +48,6 @@ func StrToMap(str string) map[string]interface{} {
return m return m
} }
func IsJsonStr(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func String2Int(str string) int { func String2Int(str string) int {
num, err := strconv.Atoi(str) num, err := strconv.Atoi(str)
if err != nil { if err != nil {

View File

@@ -1,23 +0,0 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}

View File

@@ -1,15 +1,17 @@
package common package common
import ( import (
crand "crypto/rand" "context"
"encoding/base64" "errors"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template" "html/template"
"log" "log"
"math/big"
"math/rand" "math/rand"
"net" "net"
"net/http"
"net/url"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
@@ -131,11 +133,6 @@ func IntMax(a int, b int) int {
} }
} }
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string { func GetUUID() string {
code := uuid.New().String() code := uuid.New().String()
code = strings.Replace(code, "-", "", -1) code = strings.Replace(code, "-", "", -1)
@@ -145,35 +142,24 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() { func init() {
rand.New(rand.NewSource(time.Now().UnixNano())) rand.Seed(time.Now().UnixNano())
} }
func GenerateRandomCharsKey(length int) (string, error) { func GenerateKey() string {
b := make([]byte, length)
maxI := big.NewInt(int64(len(keyChars)))
for i := range b {
n, err := crand.Int(crand.Reader, maxI)
if err != nil {
return "", err
}
b[i] = keyChars[n.Int64()]
}
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
//rand.Seed(time.Now().UnixNano()) //rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48) key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
} }
func GetRandomInt(max int) int { func GetRandomInt(max int) int {
@@ -206,3 +192,56 @@ func RandomSleep() {
// Sleep for 0-3000 ms // Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
} }
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
if "" == proxyUrl {
return &http.Client{}, nil
}
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
if strings.HasPrefix(proxyUrl, "http") {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(u),
},
}, nil
} else if strings.HasPrefix(proxyUrl, "socks") {
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
return nil, err
}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
},
}, nil
}
return nil, errors.New("unsupported proxy type")
}
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Get(url)
}
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Head(url)
}

View File

@@ -20,16 +20,14 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{ var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta", "gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta", "gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta", "gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta", "gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-pro-exp-0827": "v1beta", "gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta", "gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash-exp-0827": "v1beta", "gemini-1.5-flash": "v1beta",
"gemini-1.5-flash-001": "v1beta", "gemini-ultra": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
} }
func InitEnv() { func InitEnv() {
@@ -46,6 +44,3 @@ func InitEnv() {
} }
} }
} }
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@@ -1,9 +0,0 @@
package constant
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""
var WorkerValidKey = ""
func EnableWorker() bool {
return WorkerUrl != ""
}

View File

@@ -20,7 +20,6 @@ import (
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -82,7 +81,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
request := buildTestRequest(testModel) request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
@@ -141,22 +141,17 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return nil, nil return nil, nil
} }
func buildTestRequest(model string) *dto.GeneralOpenAIRequest { func buildTestRequest() *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later Model: "", // this will be set later
Stream: false, MaxTokens: 1,
} Stream: false,
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
Role: "user", Role: "user",
Content: content, Content: content,
} }
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
@@ -231,22 +226,26 @@ func testAllChannels(notify bool) error {
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false ban := false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
// request error disables the channel // request error disables the channel
if openaiWithStatusErr != nil { if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
} }
if milliseconds > disableThreshold { // parse *int to bool
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) if !channel.GetAutoBan() {
shouldBanChannel = true ban = false
} }
// disable channel // disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { if ban && isChannelEnabled {
service.DisableChannel(channel.Id, channel.Name, err.Error()) service.DisableChannel(channel.Id, channel.Name, err.Error())
} }

View File

@@ -198,28 +198,6 @@ func AddChannel(c *gin.Context) {
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
keys = []string{channel.Key}
}
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
@@ -319,27 +297,6 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} }
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区不能为空",
})
return
} else {
if common.IsJsonStr(channel.Other) {
// must have default
regionMap := common.StrToMap(channel.Other)
if regionMap["default"] == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "部署地区必须包含default字段",
})
return
}
}
}
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -112,9 +112,7 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{ user := model.User{
GitHubId: githubUser.Login, GitHubId: githubUser.Login,
} }
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) { if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId() err := user.FillUserByGitHubId()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -123,16 +121,10 @@ func GitHubOAuth(c *gin.Context) {
}) })
return return
} }
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@@ -143,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil { if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -17,18 +17,3 @@ func GetGroups(c *gin.Context) {
"data": groupNames, "data": groupNames,
}) })
} }
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
if _, ok := common.UserUsableGroups[groupName]; ok {
usableGroups[groupName] = common.UserUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

239
controller/linuxdo.go Normal file
View File

@@ -0,0 +1,239 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -1,19 +1,18 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"github.com/gin-gonic/gin"
) )
func GetAllLogs(c *gin.Context) { func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 { if p < 0 {
p = 1 p = 0
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@@ -25,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel) logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -36,20 +35,17 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": map[string]any{ "total": total,
"items": logs, "data": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size")) pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 { if p < 0 {
p = 1 p = 0
} }
if pageSize < 0 { if pageSize < 0 {
pageSize = common.ItemsPerPage pageSize = common.ItemsPerPage
@@ -63,7 +59,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize) logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -74,12 +70,8 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": map[string]any{ "total": total,
"items": logs, "data": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
}) })
return return
} }

View File

@@ -233,7 +233,7 @@ func GetAllMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }
@@ -265,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }

View File

@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled, "telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName, "telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName, "system_name": common.SystemName,
@@ -45,9 +47,9 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": constant.ServerAddress, "server_address": common.ServerAddress,
"price": constant.Price, "stripe_unit_price": common.StripeUnitPrice,
"min_topup": constant.MinTopUp, "min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": common.TopUpLink,
@@ -61,7 +63,7 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "", "payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
}, },
}) })
@@ -204,7 +206,7 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@@ -137,63 +137,31 @@ func init() {
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]dto.OpenAIModels, 0) userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission() permission := getPermission()
for _, s := range models {
modelLimitEnable := c.GetBool("token_model_limit_enabled") if _, ok := openAIModelsMap[s]; ok {
if modelLimitEnable { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else { } else {
tokenModelLimit = map[string]bool{} userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
} Id: s,
for allowModel, _ := range tokenModelLimit { Object: "model",
if _, ok := openAIModelsMap[allowModel]; ok { Created: 1626777600,
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel]) OwnedBy: "custom",
} else { Permission: permission,
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Root: s,
Id: allowModel, Parent: nil,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
}) })
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -7,11 +7,18 @@ import (
) )
func GetPricing(c *gin.Context) { func GetPricing(c *gin.Context) {
pricing := model.GetPricing() userId := c.GetInt("id")
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"data": pricing, "data": pricing,
"group_ratio": common.GroupRatio, "group_ratio": groupRatio,
}) })
} }

View File

@@ -38,51 +38,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err return err
} }
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
if group == "" {
group = c.GetString("group")
} else {
c.Set("group", group)
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
@@ -166,9 +121,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
if openaiErr == nil { if openaiErr == nil {
return false return false
} }
if openaiErr.LocalError {
return false
}
if retryTimes <= 0 { if retryTimes <= 0 {
return false return false
} }
@@ -199,6 +151,9 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
// azure处理超时不重试 // azure处理超时不重试
return false return false
} }
if openaiErr.LocalError {
return false
}
if openaiErr.StatusCode/100 == 2 { if openaiErr.StatusCode/100 == 2 {
return false return false
} }

97
controller/stripe.go Normal file
View File

@@ -0,0 +1,97 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"io" "io"
"net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"sort" "sort"
@@ -49,13 +48,6 @@ func TelegramBind(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId user.TelegramId = telegramId
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -123,19 +123,10 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: key, Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@@ -143,8 +134,6 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits, ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
@@ -232,8 +221,6 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@@ -1,22 +1,20 @@
package controller package controller
import "C"
import ( import (
"fmt" "fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/samber/lo" "github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/checkout/session"
"log" "log"
"net/url"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/service"
"strconv" "strconv"
"sync" "strings"
"time" "time"
) )
type EpayRequest struct { type PayRequest struct {
Amount int `json:"amount"` Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"` PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
@@ -27,201 +25,114 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
} }
func GetEpayClient() *epay.Client { func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" { if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
return nil return "", fmt.Errorf("无效的Stripe API密钥")
} }
withUrl, err := epay.NewClient(&epay.Config{
PartnerID: constant.EpayId, stripe.Key = common.StripeApiSecret
Key: constant.EpayKey,
}, constant.PayAddress) params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil { if err != nil {
return nil return "", err
} }
return withUrl
return result.URL, nil
} }
func getPayMoney(amount float64, group string) float64 { func GetPayAmount(count float64) float64 {
if !common.DisplayInCurrencyEnabled { return count * common.StripeUnitPrice
amount = amount / common.QuotaPerUnit
}
// 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * constant.Price * topupGroupRatio
return payMoney
} }
func getMinTopup() int { func GetChargedAmount(count float64, user model.User) float64 {
minTopup := constant.MinTopUp topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if !common.DisplayInCurrencyEnabled { if topUpGroupRatio == 0 {
minTopup = minTopup * int(common.QuotaPerUnit) topUpGroupRatio = 1
} }
return minTopup
return count * topUpGroupRatio
} }
func RequestEpay(c *gin.Context) { func RequestPayLink(c *gin.Context) {
var req EpayRequest var req PayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return return
} }
if req.Amount < getMinTopup() { if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())}) c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id) user, _ := model.GetUserById(id, false)
if err != nil { chargedMoney := GetChargedAmount(float64(req.Amount), *user)
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getPayMoney(float64(req.Amount), group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
var payType epay.PurchaseType reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
if req.PaymentMethod == "zfb" { referenceId := "ref_" + common.Sha1(reference)
payType = epay.Alipay
} payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil { if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return return
} }
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
}
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: amount, Amount: req.Amount,
Money: payMoney, Money: chargedMoney,
TradeNo: tradeNo, TradeNo: referenceId,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: "pending", Status: common.TopUpStatusPending,
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri}) c.JSON(200, gin.H{
} "message": "success",
"data": gin.H{
// tradeNo lock "payLink": payLink,
var orderLocks sync.Map },
var createLock sync.Mutex })
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
} }
func RequestAmount(c *gin.Context) { func RequestAmount(c *gin.Context) {
@@ -231,21 +142,23 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if !common.PaymentEnabled {
if req.Amount < getMinTopup() { c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())}) return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id) user, _ := model.GetUserById(id, false)
if err != nil { payMoney := GetPayAmount(float64(req.Amount))
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) chargedMoney := GetChargedAmount(float64(req.Amount), *user)
return c.JSON(200, gin.H{
} "message": "success",
payMoney := getPayMoney(float64(req.Amount), group) "data": gin.H{
if payMoney <= 0.01 { "payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) "chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
return },
} })
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
} }

View File

@@ -7,12 +7,10 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/constant"
) )
type LoginRequest struct { type LoginRequest struct {
@@ -68,7 +66,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username) session.Set("username", user.Username)
session.Set("role", user.Role) session.Set("role", user.Role)
session.Set("status", user.Status) session.Set("status", user.Status)
session.Set("group", user.Group) session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -160,9 +158,8 @@ func Register(c *gin.Context) {
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "数据库错误,请稍后重试", "message": err.Error(),
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@@ -190,48 +187,6 @@ func Register(c *gin.Context) {
}) })
return return
} }
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -322,18 +277,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
// get rand int 28-32 user.AccessToken = common.GetUUID()
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -574,7 +518,7 @@ func UpdateSelf(c *gin.Context) {
return return
} }
func DeleteUser(c *gin.Context) { func HardDeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -583,7 +527,7 @@ func DeleteUser(c *gin.Context) {
}) })
return return
} }
originUser, err := model.GetUserById(id, false) originUser, err := model.GetUserByIdUnscoped(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -607,9 +551,23 @@ func DeleteUser(c *gin.Context) {
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
} }
func DeleteSelf(c *gin.Context) { func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
@@ -639,7 +597,6 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -687,8 +644,8 @@ func CreateUser(c *gin.Context) {
} }
type ManageRequest struct { type ManageRequest struct {
Id int `json:"id"` Username string `json:"username"`
Action string `json:"action"` Action string `json:"action"`
} }
// ManageUser Only admin user can do this // ManageUser Only admin user can do this
@@ -704,7 +661,7 @@ func ManageUser(c *gin.Context) {
return return
} }
user := model.User{ user := model.User{
Id: req.Id, Username: req.Username,
} }
// Fill attributes // Fill attributes
model.DB.Unscoped().Where(&user).First(&user) model.DB.Unscoped().Where(&user).First(&user)

View File

@@ -78,13 +78,6 @@ func WeChatAuth(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@@ -2,18 +2,17 @@ version: '3.4'
services: services:
new-api: new-api:
image: calciumion/new-api:latest image: pengzhile/new-api:latest
# build: .
container_name: new-api container_name: new-api
restart: always restart: always
command: --log-dir /app/logs command: --log-dir /app/logs
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data:/data - ./data/new-api:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
@@ -23,13 +22,22 @@ services:
depends_on: depends_on:
- redis - redis
healthcheck: - db
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
redis: redis:
image: redis:latest image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
db:
image: mysql:8.2.0
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

View File

@@ -1,6 +0,0 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

View File

@@ -1,17 +1,14 @@
package dto package dto
type RerankRequest struct { type RerankRequest struct {
Documents []any `json:"documents"` Documents []any `json:"documents"`
Query string `json:"query"` Query string `json:"query"`
Model string `json:"model"` Model string `json:"model"`
TopN int `json:"top_n"` TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
} }
type RerankResponseDocument struct { type RerankResponseDocument struct {
Document any `json:"document,omitempty"` Document any `json:"document"`
Index int `json:"index"` Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"` RelevanceScore float64 `json:"relevance_score"`
} }

View File

@@ -7,32 +7,35 @@ type ResponseFormat struct {
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` BestOf int `json:"best_of,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` Echo bool `json:"echo,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` Stream bool `json:"stream,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Suffix string `json:"suffix,omitempty"`
TopP float64 `json:"top_p,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
TopK int `json:"top_k,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Stop any `json:"stop,omitempty"` TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"` TopK int `json:"top_k,omitempty"`
Input any `json:"input,omitempty"` Stop any `json:"stop,omitempty"`
Instruction string `json:"instruction,omitempty"` N int `json:"n,omitempty"`
Size string `json:"size,omitempty"` Input any `json:"input,omitempty"`
Functions any `json:"functions,omitempty"` Instruction string `json:"instruction,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` Size string `json:"size,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` Functions any `json:"functions,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
Seed float64 `json:"seed,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` Seed float64 `json:"seed,omitempty"`
User string `json:"user,omitempty"` Tools []ToolCall `json:"tools,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` User string `json:"user,omitempty"`
Dimensions int `json:"dimensions,omitempty"` LogitBias any `json:"logit_bias,omitempty"`
LogProbs any `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
} }
type OpenAITools struct { type OpenAITools struct {

View File

@@ -34,7 +34,6 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct { type OpenAITextResponse struct {
Id string `json:"id"` Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`

25
go.mod
View File

@@ -1,17 +1,13 @@
module one-api module one-api
// +heroku goVersion go1.18 // +heroku goVersion go1.18
go 1.21 go 1.18
toolchain go1.22.4
require ( require (
github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
@@ -27,7 +23,8 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0 github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
golang.org/x/crypto v0.26.0 github.com/stripe/stripe-go/v76 v76.21.0
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0 golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
@@ -41,8 +38,9 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect
@@ -53,7 +51,6 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect
@@ -68,11 +65,9 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/stretchr/testify v1.9.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -80,10 +75,10 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.8.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.24.0 // indirect golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

56
go.sum
View File

@@ -1,5 +1,3 @@
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
@@ -23,8 +21,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -34,10 +32,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
@@ -58,7 +57,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -83,9 +81,10 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -129,6 +128,8 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/linux-do/tiktoken-go v0.7.0 h1:Kcm/miJ5gp77srtF8GQWnfq7W9kTaXEuHZg/g9IVEu8=
github.com/linux-do/tiktoken-go v0.7.0/go.mod h1:9Vkdtp0ngi4USmrdSx984iuIQ5IMr0hnUdz4jZZTJb8=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -136,8 +137,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -145,17 +144,16 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -178,8 +176,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -198,21 +197,23 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -221,27 +222,26 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

View File

@@ -42,11 +42,6 @@ func main() {
if err != nil { if err != nil {
common.FatalLog("failed to initialize database: " + err.Error()) common.FatalLog("failed to initialize database: " + err.Error())
} }
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {

View File

@@ -10,23 +10,13 @@ import (
"strings" "strings"
) )
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) { func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c) session := sessions.Default(c)
username := session.Get("username") username := session.Get("username")
role := session.Get("role") role := session.Get("role")
id := session.Get("id") id := session.Get("id")
status := session.Get("status") status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false useAccessToken := false
if username == nil { if username == nil {
// Check access token // Check access token
@@ -41,19 +31,12 @@ func authHelper(c *gin.Context, minRole int) {
} }
user := model.ValidateAccessToken(accessToken) user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" { if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid // Token is valid
username = user.Username username = user.Username
role = user.Role role = user.Role
id = user.Id id = user.Id
status = user.Status status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true useAccessToken = true
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -102,6 +85,14 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole { if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -110,19 +101,9 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username) c.Set("username", username)
c.Set("role", role) c.Set("role", role)
c.Set("id", id) c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next() c.Next()
} }
@@ -175,7 +156,7 @@ func TokenAuth() func(c *gin.Context) {
if token != nil { if token != nil {
id := c.GetInt("id") id := c.GetInt("id")
if id == 0 { if id == 0 {
c.Set("id", token.UserId) c.Set("id", token.Id)
} }
} }
if err != nil { if err != nil {
@@ -191,6 +172,15 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
@@ -204,8 +194,6 @@ func TokenAuth() func(c *gin.Context) {
} else { } else {
c.Set("token_model_limit_enabled", false) c.Set("token_model_limit_enabled", false)
} }
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1]) c.Set("specific_channel_id", parts[1])

View File

@@ -22,14 +22,6 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
@@ -39,20 +31,6 @@ func Distribute() func(c *gin.Context) {
return return
} }
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.UserUsableGroups[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup) c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
@@ -221,8 +199,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei: case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeGemini: case common.ChannelTypeGemini:

View File

@@ -36,12 +36,6 @@ func GetEnabledModels() []string {
return models return models
} }
func GetAllEnableAbilities() []Ability {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) { func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"

View File

@@ -205,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err return userEnabled, err
} }
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
@@ -269,9 +293,8 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} } else if strings.HasPrefix(model, "g-") {
if strings.HasPrefix(model, "gpt-4o-gizmo") { model = "g-*"
model = "gpt-4o-gizmo-*"
} }
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database

View File

@@ -106,23 +106,16 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
// 构造WHERE子句 // 构造WHERE子句
var whereClause string var whereClause string
var args []interface{} var args []interface{}
if group != "" && group != "null" { if group != "" {
var groupCondition string whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?"
if common.UsingMySQL { args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%")
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%", "%,"+group+",%")
} else { } else {
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
} }
// 执行查询 // 执行查询
err := baseQuery.Where(whereClause, args...).Order("priority desc").Find(&channels).Error err := baseQuery.Where(whereClause, args...).Find(&channels).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,12 +3,11 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
@@ -39,7 +38,7 @@ const (
) )
func GetLogByKey(key string) (logs []*Log, err error) { func GetLogByKey(key string) (logs []*Log, err error) {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error err = DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
return logs, err return logs, err
} }
@@ -55,7 +54,7 @@ func RecordLog(userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := LOG_DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.SysError("failed to record log: " + err.Error())
} }
@@ -85,7 +84,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
IsStream: isStream, IsStream: isStream,
Other: otherStr, Other: otherStr,
} }
err := LOG_DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
@@ -99,9 +98,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = LOG_DB tx = DB
} else { } else {
tx = LOG_DB.Where("type = ?", logType) tx = DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name like ?", modelName)
@@ -121,23 +120,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Model(&Log{}).Count(&total).Error err = tx.Model(&Log{}).Count(&total).Error
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
return nil, 0, err
}
return logs, total, err return logs, total, err
} }
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = LOG_DB.Where("user_id = ?", userId) tx = DB.Where("user_id = ?", userId)
} else { } else {
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) tx = DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name like ?", modelName)
@@ -151,11 +149,14 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
err = tx.Model(&Log{}).Count(&total).Error err = tx.Model(&Log{}).Count(&total).Error
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, total, err
for i := range logs { for i := range logs {
var otherMap map[string]interface{} var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other) otherMap = common.StrToMap(logs[i].Other)
@@ -169,12 +170,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }
@@ -185,10 +186,10 @@ type Stat struct {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota") tx := DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询 // 为rpm和tpm创建单独的查询
rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@@ -205,8 +206,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName) rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
@@ -227,7 +228,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -248,6 +249,6 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func DeleteOldLog(targetTimestamp int64) (int64, error) { func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -15,8 +15,6 @@ import (
var DB *gorm.DB var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error { func createRootAccountIfNeed() error {
var user User var user User
//if user.Status != common.UserStatusEnabled { //if user.Status != common.UserStatusEnabled {
@@ -32,7 +30,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: nil, AccessToken: common.GetUUID(),
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -40,9 +38,9 @@ func createRootAccountIfNeed() error {
return nil return nil
} }
func chooseDB(envName string) (*gorm.DB, error) { func chooseDB() (*gorm.DB, error) {
dsn := os.Getenv(envName) if os.Getenv("SQL_DSN") != "" {
if dsn != "" { dsn := os.Getenv("SQL_DSN")
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
@@ -54,13 +52,6 @@ func chooseDB(envName string) (*gorm.DB, error) {
PrepareStmt: true, // precompile SQL PrepareStmt: true, // precompile SQL
}) })
} }
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL // Use MySQL
common.SysLog("using MySQL as database") common.SysLog("using MySQL as database")
// check parseTime // check parseTime
@@ -85,7 +76,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
} }
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB("SQL_DSN") db, err := chooseDB()
if err == nil { if err == nil {
if common.DebugEnabled { if common.DebugEnabled {
db = db.Debug() db = db.Debug()
@@ -109,44 +100,52 @@ func InitDB() (err error) {
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
//} //}
common.SysLog("database migration started") common.SysLog("database migration started")
err = migrateDB() err = db.AutoMigrate(&Channel{})
return err
} else {
common.FatalLog(err)
}
return err
}
func InitLogDB() (err error) {
if os.Getenv("LOG_SQL_DSN") == "" {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
LOG_DB = db
sqlDB, err := LOG_DB.DB()
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) err = db.AutoMigrate(&Token{})
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) if err != nil {
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) return err
if !common.IsMasterNode {
return nil
} }
//if common.UsingMySQL { err = db.AutoMigrate(&User{})
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded if err != nil {
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded return err
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded }
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded err = db.AutoMigrate(&Option{})
//} if err != nil {
common.SysLog("database migration started") return err
err = migrateLOGDB() }
err = db.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return err
}
err = db.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = db.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = db.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err return err
} else { } else {
common.FatalLog(err) common.FatalLog(err)
@@ -154,66 +153,8 @@ func InitLogDB() (err error) {
return err return err
} }
func migrateDB() error { func CloseDB() error {
err := DB.AutoMigrate(&Channel{}) sqlDB, err := DB.DB()
if err != nil {
return err
}
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {
var err error
if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
return err
}
return nil
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil { if err != nil {
return err return err
} }
@@ -221,16 +162,6 @@ func closeDB(db *gorm.DB) error {
return err return err
} }
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}
var ( var (
lastPingTime time.Time lastPingTime time.Time
pingMutex sync.Mutex pingMutex sync.Mutex

View File

@@ -31,10 +31,12 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
@@ -60,17 +62,19 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = constant.WorkerUrl common.OptionMap["OutProxyUrl"] = ""
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
common.OptionMap["PayAddress"] = "" common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["StripePriceId"] = common.StripePriceId
common.OptionMap["EpayId"] = "" common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
common.OptionMap["EpayKey"] = "" common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64) common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = "" common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerAddress"] = ""
@@ -86,7 +90,6 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
@@ -174,6 +177,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled": case "TelegramOAuthEnabled":
@@ -182,6 +187,8 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled": case "EmailAliasRestrictionEnabled":
@@ -241,29 +248,33 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
constant.ServerAddress = value common.ServerAddress = value
case "WorkerUrl": case "OutProxyUrl":
constant.WorkerUrl = value common.OutProxyUrl = value
case "WorkerValidKey": case "StripeApiSecret":
constant.WorkerValidKey = value common.StripeApiSecret = value
case "PayAddress": case "StripeWebhookSecret":
constant.PayAddress = value common.StripeWebhookSecret = value
case "CustomCallbackAddress": case "StripePriceId":
constant.CustomCallbackAddress = value common.StripePriceId = value
case "EpayId": case "PaymentEnabled":
constant.EpayId = value common.PaymentEnabled, _ = strconv.ParseBool(value)
case "EpayKey": case "StripeUnitPrice":
constant.EpayKey = value common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
constant.MinTopUp, _ = strconv.Atoi(value) common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value) err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId": case "GitHubClientId":
common.GitHubClientId = value common.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer": case "Footer":
common.Footer = value common.Footer = value
case "SystemName": case "SystemName":
@@ -304,8 +315,6 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value) err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice": case "ModelPrice":

View File

@@ -7,13 +7,14 @@ import (
) )
type Pricing struct { type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"` QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"` ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"` ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"` OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"` CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"` EnableGroup []string `json:"enable_group,omitempty"`
} }
var ( var (
@@ -22,47 +23,40 @@ var (
updatePricingLock sync.Mutex updatePricingLock sync.Mutex
) )
func GetPricing() []Pricing { func GetPricing(group string) []Pricing {
updatePricingLock.Lock() updatePricingLock.Lock()
defer updatePricingLock.Unlock() defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing() updatePricing()
} }
//if group != "" { if group != "" {
// userPricingMap := make([]Pricing, 0) userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group) models := GetGroupModels(group)
// for _, pricing := range pricingMap { for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) { if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false pricing.Available = false
// } }
// userPricingMap = append(userPricingMap, pricing) userPricingMap = append(userPricingMap, pricing)
// } }
// return userPricingMap return userPricingMap
//} }
return pricingMap return pricingMap
} }
func updatePricing() { func updatePricing() {
//modelRatios := common.GetModelRatios() //modelRatios := common.GetModelRatios()
enableAbilities := GetAllEnableAbilities() enabledModels := GetEnabledModels()
modelGroupsMap := make(map[string][]string) allModels := make(map[string]int)
for _, ability := range enableAbilities { for i, model := range enabledModels {
groups := modelGroupsMap[ability.Model] allModels[model] = i
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
} }
pricingMap = make([]Pricing, 0) pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap { for model, _ := range allModels {
pricing := Pricing{ pricing := Pricing{
ModelName: model, Available: true,
EnableGroup: groups, ModelName: model,
} }
modelPrice, findPrice := common.GetModelPrice(model, false) modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice { if findPrice {

View File

@@ -5,8 +5,6 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"strconv" "strconv"
"strings" "strings"
) )
@@ -24,34 +22,10 @@ type Token struct {
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"` ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
var tokens []*Token var tokens []*Token
var err error var err error
@@ -155,8 +129,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err return err
} }
@@ -258,56 +231,51 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err return err
} }
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) { func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
if quota < 0 { if quota < 0 {
return 0, errors.New("quota 不能为负数!") return 0, errors.New("quota 不能为负数!")
} }
if !relayInfo.IsPlayground { token, err := GetTokenById(tokenId)
token, err := GetTokenById(relayInfo.TokenId) if err != nil {
if err != nil { return 0, err
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
} }
userQuota, err = GetUserQuota(relayInfo.UserId) if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if userQuota < quota { if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
if !relayInfo.IsPlayground { err = DecreaseTokenQuota(tokenId, quota)
err = DecreaseTokenQuota(relayInfo.TokenId, quota) if err != nil {
if err != nil { return 0, err
return 0, err
}
} }
err = DecreaseUserQuota(relayInfo.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
return userQuota - quota, err return userQuota - quota, err
} }
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
} else { } else {
err = IncreaseUserQuota(relayInfo.UserId, -quota) err = IncreaseUserQuota(token.UserId, -quota)
} }
if err != nil { if err != nil {
return err return err
} }
if !relayInfo.IsPlayground { if quota > 0 {
if quota > 0 { err = DecreaseTokenQuota(tokenId, quota)
err = DecreaseTokenQuota(relayInfo.TokenId, quota) } else {
} else { err = IncreaseTokenQuota(tokenId, -quota)
err = IncreaseTokenQuota(relayInfo.TokenId, -quota) }
} if err != nil {
if err != nil { return err
return err
}
} }
if sendEmail { if sendEmail {
@@ -316,7 +284,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(relayInfo.UserId) email, err := GetUserEmail(token.UserId)
if err != nil { if err != nil {
common.SysError("failed to fetch user email: " + err.Error()) common.SysError("failed to fetch user email: " + err.Error())
} }
@@ -325,7 +293,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@@ -1,13 +1,21 @@
package model package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"` Amount int `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no"` TradeNo string `json:"trade_no" gorm:"unique"`
CreateTime int64 `json:"create_time"` CreateTime int64 `json:"create_time"`
Status string `json:"status"` CompleteTime int64 `json:"complete_time"`
Status string `json:"status"`
} }
func (topUp *TopUp) Insert() error { func (topUp *TopUp) Insert() error {
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
} }
return topUp return topUp
} }
func Recharge(referenceId string, customerId string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota float64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
quota = topUp.Money * common.QuotaPerUnit
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.LogQuotaF(quota), topUp.Amount))
return nil
}

View File

@@ -22,10 +22,12 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@@ -35,20 +37,10 @@ type User struct {
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度 AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度 AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (user *User) GetAccessToken() string {
if user.AccessToken == nil {
return ""
}
return *user.AccessToken
}
func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) { func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User var user User
@@ -75,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int { func GetMaxUserId() int {
var user User var user User
DB.Last(&user) DB.Unscoped().Last(&user)
return user.Id return user.Id
} }
@@ -130,6 +122,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err return &user, err
} }
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
if id == 0 {
return nil, errors.New("id 为空!")
}
user := User{Id: id}
var err error = nil
if selectAll {
err = DB.Unscoped().First(&user, "id = ?", id).Error
} else {
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
}
return &user, err
}
func GetUserIdByAffCode(affCode string) (int, error) { func GetUserIdByAffCode(affCode string) (int, error) {
if affCode == "" { if affCode == "" {
return 0, errors.New("affCode 为空!") return 0, errors.New("affCode 为空!")
@@ -212,7 +218,7 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = common.QuotaForNewUser
//user.SetAccessToken(common.GetUUID()) user.AccessToken = common.GetUUID()
user.AffCode = common.GetRandomString(4) user.AffCode = common.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
@@ -306,12 +312,11 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values, // that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions // it wont be used to build query conditions
password := user.Password password := user.Password
username := strings.TrimSpace(user.Username) if user.Username == "" || password == "" {
if username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
// find buy username or email // find buy username or email
DB.Where("username = ? OR email = ?", username, username).First(user) DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@@ -343,6 +348,14 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -351,6 +364,14 @@ func (user *User) FillUserByWeChatId() error {
return nil return nil
} }
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
}
DB.Where(User{Username: user.Username}).First(user)
return nil
}
func (user *User) FillUserByTelegramId() error { func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" { if user.TelegramId == "" {
return errors.New("Telegram id 为空!") return errors.New("Telegram id 为空!")
@@ -363,19 +384,27 @@ func (user *User) FillUserByTelegramId() error {
} }
func IsEmailAlreadyTaken(email string) bool { func IsEmailAlreadyTaken(email string) bool {
return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1 return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
} }
func IsWeChatIdAlreadyTaken(wechatId string) bool { func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
} }
func IsGitHubIdAlreadyTaken(githubId string) bool { func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }
func IsTelegramIdAlreadyTaken(telegramId string) bool { func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
} }
func ResetUserPasswordByEmail(email string, password string) error { func ResetUserPasswordByEmail(email string, password string) error {
@@ -415,6 +444,18 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled, nil
} }
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
return nil return nil

View File

@@ -105,7 +105,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
for _, data := range response.Output.Results { for _, data := range response.Output.Results {
var b64Json string var b64Json string
if responseFormat == "b64_json" { if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url) _, b64, err := common.GetImageFromUrl(data.Url)
if err != nil { if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error()) common.LogError(c, "get_image_data_failed: "+err.Error())
continue continue

View File

@@ -8,6 +8,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings"
) )
const ( const (
@@ -30,7 +31,11 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -48,8 +53,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest var claudeReq *claude.ClaudeRequest
var err error var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) if a.RequestMode == RequestModeCompletion {
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
} else {
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
}
c.Set("request_model", request.Model) c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq) c.Set("converted_request", claudeReq)
return claudeReq, err return claudeReq, err

View File

@@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
} else { } else {
err, usage = ClaudeHandler(c, resp, a.RequestMode, info) err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@@ -31,9 +31,9 @@ type ClaudeMessage struct {
} }
type Tool struct { type Tool struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"input_schema"` InputSchema InputSchema `json:"input_schema"`
} }
type InputSchema struct { type InputSchema struct {

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -11,8 +12,6 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@@ -64,21 +63,15 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok { if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTool := Tool{ claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name, Name: tool.Function.Name,
Description: tool.Function.Description, Description: tool.Function.Description,
} InputSchema: InputSchema{
claudeTool.InputSchema = make(map[string]interface{}) Type: params["type"].(string),
claudeTool.InputSchema["type"] = params["type"].(string) Properties: params["properties"],
claudeTool.InputSchema["properties"] = params["properties"] Required: params["required"],
claudeTool.InputSchema["required"] = params["required"] },
for s, a := range params { })
if s == "type" || s == "properties" || s == "required" {
continue
}
claudeTool.InputSchema[s] = a
}
claudeTools = append(claudeTools, claudeTool)
} }
} }
@@ -109,10 +102,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
} }
formatMessages := make([]dto.Message, 0) formatMessages := make([]dto.Message, 0)
lastMessage := dto.Message{ var lastMessage *dto.Message
Role: "tool",
}
for i, message := range textRequest.Messages { for i, message := range textRequest.Messages {
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" { if message.Role == "" {
textRequest.Messages[i].Role = "user" textRequest.Messages[i].Role = "user"
} }
@@ -120,13 +116,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,
} }
if message.Role == "tool" { if lastMessage != nil && lastMessage.Role == message.Role {
fmtMessage.ToolCallId = message.ToolCallId
}
if message.Role == "assistant" && message.ToolCalls != nil {
fmtMessage.ToolCalls = message.ToolCalls
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() { if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content fmtMessage.Content = content
@@ -139,11 +129,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
fmtMessage.Content = content fmtMessage.Content = content
} }
formatMessages = append(formatMessages, fmtMessage) formatMessages = append(formatMessages, fmtMessage)
lastMessage = fmtMessage lastMessage = &textRequest.Messages[i]
} }
claudeMessages := make([]ClaudeMessage, 0) claudeMessages := make([]ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages { for _, message := range formatMessages {
if message.Role == "system" { if message.Role == "system" {
if message.IsStringContent() { if message.IsStringContent() {
@@ -159,54 +148,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeRequest.System = content claudeRequest.System = content
} }
} else { } else {
if isFirstMessage {
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
{
Type: "text",
Text: "...",
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{ claudeMessage := ClaudeMessage{
Role: message.Role, Role: message.Role,
} }
if message.Role == "tool" { if message.IsStringContent() {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
{
Type: "text",
Text: content,
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent() claudeMessage.Content = message.StringContent()
} else { } else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0) claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -225,11 +170,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url // 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url) mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data claudeMediaMessage.Source.Data = data
} else { } else {
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) _, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -239,28 +184,6 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
} }
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
Input: inputObj,
})
}
}
claudeMessage.Content = claudeMediaMessages claudeMessage.Content = claudeMediaMessages
} }
claudeMessages = append(claudeMessages, claudeMessage) claudeMessages = append(claudeMessages, claudeMessage)
@@ -395,13 +318,12 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.ToolCalls = tools choice.Message.ToolCalls = tools
} }
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice) choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse
} }
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage var usage *dto.Usage
usage = &dto.Usage{} usage = &dto.Usage{}
@@ -483,7 +405,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, usage return nil, usage
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -509,15 +431,15 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}, nil }, nil
} }
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
} }
usage := dto.Usage{} usage := dto.Usage{}
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
usage.PromptTokens = info.PromptTokens usage.PromptTokens = promptTokens
usage.CompletionTokens = completionTokens usage.CompletionTokens = completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens usage.TotalTokens = promptTokens + completionTokens
} else { } else {
usage.PromptTokens = claudeResponse.Usage.InputTokens usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens usage.CompletionTokens = claudeResponse.Usage.OutputTokens

View File

@@ -1,10 +1,7 @@
package cohere package cohere
var ModelList = []string{ var ModelList = []string{
"command-r", "command-r-plus", "command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
"command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0", "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
} }

View File

@@ -8,7 +8,6 @@ type CohereRequest struct {
Message string `json:"message"` Message string `json:"message"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
} }
type ChatHistory struct { type ChatHistory struct {

View File

@@ -23,9 +23,6 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
Stream: textRequest.Stream, Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(), MaxTokens: textRequest.GetMaxTokens(),
} }
if common.CohereSafetySetting != "NONE" {
cohereReq.SafetyMode = common.CohereSafetySetting
}
if cohereReq.MaxTokens == 0 { if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000 cohereReq.MaxTokens = 4000
} }
@@ -47,7 +44,6 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
}) })
} }
} }
return &cohereReq return &cohereReq
} }

View File

@@ -70,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info) err, usage = geminiChatStreamHandler(c, resp, info)
} else { } else {
err, usage = GeminiChatHandler(c, resp) err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@@ -6,7 +6,7 @@ const (
var ModelList = []string{ var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra", "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827", "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
} }
var ChannelName = "google gemini" var ChannelName = "google gemini"

View File

@@ -86,7 +86,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
// 判断是否是url // 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,
@@ -94,7 +94,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
}, },
}) })
} else { } else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) _, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil { if err != nil {
continue continue
} }
@@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response return &response
} }
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := "" responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp() createAt := common.GetTimestamp()
@@ -279,7 +279,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
return nil, usage return nil, usage
} }
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
} }
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }
@@ -58,8 +58,6 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp) err, usage = jinaRerankHandler(c, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
} }
return return
} }

View File

@@ -33,28 +33,3 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage return nil, &jinaResp.Usage
} }
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.OpenAIEmbeddingResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -3,39 +3,25 @@ package ollama
import "one-api/dto" import "one-api/dto"
type OllamaRequest struct { type OllamaRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"` Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"` Stop any `json:"stop,omitempty"`
Tools []dto.ToolCall `json:"tools,omitempty"` Tools []dto.ToolCall `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
} }
type OllamaEmbeddingRequest struct { type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Input []string `json:"input"` Prompt any `json:"prompt,omitempty"`
Options *Options `json:"options,omitempty"`
} }
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"` Embedding []float64 `json:"embedding,omitempty"`
} }

View File

@@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/service" "one-api/service"
"strings"
) )
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -44,15 +45,8 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Input: request.ParseInput(), Prompt: strings.Join(request.ParseInput(), " "),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@@ -70,9 +64,6 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{ data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding, Embedding: ollamaEmbeddingResponse.Embedding,

View File

@@ -78,12 +78,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if info.ChannelType != common.ChannelTypeOpenAI { if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil request.StreamOptions = nil
} }
if strings.HasPrefix(request.Model, "o1-") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
}
return request, nil return request, nil
} }

View File

@@ -1,24 +1,21 @@
package openai package openai
var ModelList = []string{ var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"chatgpt-4o-latest", "gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
"text-moderation-latest", "text-moderation-stable", "text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001", "text-davinci-edit-001",
"davinci-002", "babbage-002", "davinci-002", "babbage-002",
"dall-e-3", "dall-e-2", "dall-e-3",
"whisper-1", "whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
} }

View File

@@ -1,83 +0,0 @@
package siliconflow
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRerank:
err, usage = siliconflowRerankHandler(c, resp)
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,51 +0,0 @@
package siliconflow
var ModelList = []string{
"THUDM/glm-4-9b-chat",
//"stabilityai/stable-diffusion-xl-base-1.0",
//"TencentARC/PhotoMaker",
"InstantX/InstantID",
//"stabilityai/stable-diffusion-2-1",
//"stabilityai/sd-turbo",
//"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
//"stabilityai/stable-diffusion-3-medium",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"iic/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
"Qwen/Qwen2-Math-72B-Instruct",
"netease-youdao/bce-reranker-base_v1",
"BAAI/bge-reranker-v2-m3",
}
var ChannelName = "siliconflow"

View File

@@ -1,17 +0,0 @@
package siliconflow
import "one-api/dto"
type SFTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type SFMeta struct {
Tokens SFTokens `json:"tokens"`
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -1,44 +0,0 @@
package siliconflow
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
}
rerankResp := &dto.RerankResponse{
Results: siliconflowResp.Results,
Usage: *usage,
}
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, usage
}

View File

@@ -1,184 +0,0 @@
package vertex
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
)
const (
RequestModeClaude = 1
RequestModeGemini = 2
RequestModeLlama = 3
)
var claudeModelMap = map[string]string{
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
}
const anthropicVersion = "vertex-2023-10-16"
type Adaptor struct {
RequestMode int
AccountCredentials Credentials
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &Credentials{}
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
info.UpstreamModelName = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
vertexClaudeReq := &VertexAIClaudeRequest{
AnthropicVersion: anthropicVersion,
}
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
return nil, errors.New("failed to copy claude request")
}
c.Set("request_model", request.Model)
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest := gemini.CovertGemini2OpenAI(*request)
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeLlama {
return request, nil
}
return nil, errors.New("unsupported request mode")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
case RequestModeLlama:
err, usage = openai.OaiStreamHandler(c, resp, info)
}
} else {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp)
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,15 +0,0 @@
package vertex
var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"meta/llama3-405b-instruct-maas",
}
var ChannelName = "vertex-ai"

View File

@@ -1,17 +0,0 @@
package vertex
import "one-api/relay/channel/claude"
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@@ -1,16 +0,0 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonStr(other) {
m := common.StrToMap(other)
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@@ -1,122 +0,0 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"strings"
"fmt"
"time"
)
type Credentials struct {
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
}
var Cache = asynccache.NewAsyncCache(asynccache.Options{
RefreshDuration: time.Minute * 35,
EnableExpire: true,
ExpireDuration: time.Minute * 30,
Fetcher: func(key string) (interface{}, error) {
return nil, errors.New("not found")
},
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil
}
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
newToken, err := exchangeJwtForAccessToken(signedJWT)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
if err := Cache.SetDefault(cacheKey, newToken); err {
return newToken, nil
}
return newToken, nil
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
resp, err := http.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if accessToken, ok := result["access_token"].(string); ok {
return accessToken, nil
}
return "", fmt.Errorf("failed to get access token: %v", result)
}

View File

@@ -1,7 +1,7 @@
package zhipu_4v package zhipu_4v
var ModelList = []string{ var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus", "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
} }
var ChannelName = "zhipu_4v" var ChannelName = "zhipu_4v"

View File

@@ -20,10 +20,8 @@ type RelayInfo struct {
setFirstResponse bool setFirstResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string
RequestURLPath string RequestURLPath string
ApiVersion string ApiVersion string
PromptTokens int PromptTokens int
@@ -59,27 +57,17 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
TokenUnlimited: tokenUnlimited, TokenUnlimited: tokenUnlimited,
StartTime: startTime, StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second), FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
ApiType: apiType, ApiType: apiType,
ApiVersion: c.GetString("api_version"), ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" { if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType] info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
if info.ChannelType == common.ChannelTypeAzure { if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c) info.ApiVersion = GetAPIVersion(c)
} }
if info.ChannelType == common.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare { info.ChannelType == common.ChannelCloudflare {
@@ -152,20 +140,3 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
} }
return info return info
} }
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -23,8 +23,6 @@ const (
APITypeDify APITypeDify
APITypeJina APITypeJina
APITypeCloudflare APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -68,10 +66,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeJina apiType = APITypeJina
case common.ChannelCloudflare: case common.ChannelCloudflare:
apiType = APITypeCloudflare apiType = APITypeCloudflare
case common.ChannelTypeSiliconFlow:
apiType = APITypeSiliconFlow
case common.ChannelTypeVertexAi:
apiType = APITypeVertexAi
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

View File

@@ -42,7 +42,7 @@ const (
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") { if strings.HasPrefix(path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") { } else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions relayMode = RelayModeCompletions

View File

@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
@@ -87,7 +87,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
@@ -126,7 +126,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -136,7 +136,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@@ -38,7 +38,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Model == "" { if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2" imageRequest.Model = "dall-e-2"
} }
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
// Not "256x256", "512x512", or "1024x1024" // Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
@@ -48,9 +50,6 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
} }
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
}
//if imageRequest.N != 1 { //if imageRequest.N != 1 {
// return nil, errors.New("n must be 1") // return nil, errors.New("n must be 1")
//} //}
@@ -126,7 +125,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quota := int(imageRatio * groupRatio * common.QuotaPerUnit) quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
if userQuota-quota < 0 { if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
@@ -31,7 +30,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}) })
return return
} }
resp, err := http.Get(midjourneyTask.ImageUrl) resp, err := common.ProxiedHttpGet(midjourneyTask.ImageUrl, common.OutProxyUrl)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed", "error": "http_get_image_failed",
@@ -112,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }
@@ -147,7 +146,6 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest) err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil { if err != nil {
@@ -193,7 +191,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@@ -358,7 +356,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
consumeQuota := true consumeQuota := true
var midjRequest dto.MidjourneyRequest var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -498,7 +495,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
} }
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations: case relayconstant.RelayModeModerations:
if textRequest.Input == "" || textRequest.Input == nil { if textRequest.Input == "" {
return nil, errors.New("field input is required") return nil, errors.New("field input is required")
} }
case relayconstant.RelayModeEdits: case relayconstant.RelayModeEdits:
@@ -178,7 +178,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp != nil { if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -188,7 +188,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
@@ -205,6 +205,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model) promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
case relayconstant.RelayModeCompletions: case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
prompts := textRequest.Prompt
switch v := prompts.(type) {
case string:
prompts = v + textRequest.Suffix
case []string:
prompts = append(v, textRequest.Suffix)
}
promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
case relayconstant.RelayModeModerations: case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
@@ -238,12 +247,9 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota <= 0 { if userQuota <= 0 || userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
@@ -266,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
} }
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
@@ -274,11 +280,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return preConsumedQuota, userQuota, nil return preConsumedQuota, userQuota, nil
} }
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func(ctx context.Context) { go func(ctx context.Context) {
// return pre-consumed quota // return pre-consumed quota
err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false) err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.SysError("error return pre-consumed quota: " + err.Error())
} }
@@ -317,7 +323,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
totalTokens := promptTokens + completionTokens totalTokens := promptTokens + completionTokens
var logContent string var logContent string
if !usePrice { if !usePrice {
logContent = fmt.Sprintf("模型倍率 %.2f补全倍率 %.2f分组倍率 %.2f", modelRatio, completionRatio, groupRatio) logContent = fmt.Sprintf("模型倍率 %.2f分组倍率 %.2f补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
} else { } else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
} }
@@ -336,7 +342,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//} //}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
@@ -353,9 +359,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
if strings.HasPrefix(logModel, "gpt-4-gizmo") { if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*" logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName) logContent += fmt.Sprintf(",模型 %s", modelName)
} } else if strings.HasPrefix(logModel, "g-") {
if strings.HasPrefix(logModel, "gpt-4o-gizmo") { logModel = "g-*"
logModel = "gpt-4o-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName) logContent += fmt.Sprintf(",模型 %s", modelName)
} }
if extraContent != "" { if extraContent != "" {

View File

@@ -16,10 +16,8 @@ import (
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/channel/palm" "one-api/relay/channel/palm"
"one-api/relay/channel/perplexity" "one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/suno" "one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent" "one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/xunfei" "one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v" "one-api/relay/channel/zhipu_4v"
@@ -64,10 +62,6 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &jina.Adaptor{} return &jina.Adaptor{}
case constant.APITypeCloudflare: case constant.APITypeCloudflare:
return &cloudflare.Adaptor{} return &cloudflare.Adaptor{}
case constant.APITypeSiliconFlow:
return &siliconflow.Adaptor{}
case constant.APITypeVertexAi:
return &vertex.Adaptor{}
} }
return nil return nil
} }

View File

@@ -38,23 +38,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
if len(rerankRequest.Documents) == 0 { if len(rerankRequest.Documents) == 0 {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
} }
// map model name
modelMapping := c.GetString("model_mapping")
//isModelMapped := false
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[rerankRequest.Model] != "" {
rerankRequest.Model = modelMap[rerankRequest.Model]
// set upstream model name
//isModelMapped = true
}
}
relayInfo.UpstreamModelName = rerankRequest.Model relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := common.GetGroupRatio(relayInfo.Group)
@@ -101,7 +84,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
} }
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -111,7 +94,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@@ -111,8 +111,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
defer func(ctx context.Context) { defer func(ctx context.Context) {
// release quota // release quota
if relayInfo.ConsumeQuota && taskErr == nil { if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -18,13 +18,14 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/about", controller.GetAbout)
//apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
@@ -32,14 +33,14 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
userRoute := apiRouter.Group("/user") userRoute := apiRouter.Group("/user")
{ {
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout) userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
userRoute.GET("/groups", controller.GetUserGroups)
selfRoute := userRoute.Group("/") selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth()) selfRoute.Use(middleware.UserAuth())
@@ -50,8 +51,8 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.DELETE("/self", controller.DeleteSelf) selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode) selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp) selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
selfRoute.POST("/pay", controller.RequestEpay) selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota) selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
} }
@@ -65,7 +66,7 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/", controller.CreateUser) adminRoute.POST("/", controller.CreateUser)
adminRoute.POST("/manage", controller.ManageUser) adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser) adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser) adminRoute.DELETE("/:id", controller.HardDeleteUser)
} }
} }
optionRoute := apiRouter.Group("/option") optionRoute := apiRouter.Group("/option")

View File

@@ -16,11 +16,6 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("", controller.ListModels) modelsRouter.GET("", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel) modelsRouter.GET("/:model", controller.RetrieveModel)
} }
playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{ {

View File

@@ -54,8 +54,6 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
switch err.Error.Type { switch err.Error.Type {
case "insufficient_quota": case "insufficient_quota":
return true return true
case "insufficient_user_quota":
return true
// https://docs.anthropic.com/claude/reference/errors // https://docs.anthropic.com/claude/reference/errors
case "authentication_error": case "authentication_error":
return true return true
@@ -73,15 +71,6 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
} else if strings.HasPrefix(err.Error.Message, "Permission denied") { } else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true return true
} }
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
return true
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
return true
}
return false return false
} }

View File

@@ -1,12 +0,0 @@
package service
import (
"one-api/constant"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
return constant.ServerAddress
}
return constant.CustomCallbackAddress
}

View File

@@ -28,11 +28,13 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error() text := err.Error()
lowerText := strings.ToLower(text) // 定义一个正则表达式匹配URL
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
common.SysLog(fmt.Sprintf("error: %s", text)) common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败" text = "请求上游地址失败"
} }
//避免暴露内部错误
openAIError := dto.OpenAIError{ openAIError := dto.OpenAIError{
Message: text, Message: text,
Type: "new_api_error", Type: "new_api_error",
@@ -111,12 +113,14 @@ func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskErro
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
text := err.Error() text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { // 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
common.SysLog(fmt.Sprintf("error: %s", text)) common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败" text = "请求上游地址失败"
} }
//避免暴露内部错误 //避免暴露内部错误
taskError := &dto.TaskError{ taskError := &dto.TaskError{
Code: code, Code: code,
Message: text, Message: text,

View File

@@ -18,7 +18,6 @@ import (
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
common.SysLog("initializing token encoders") common.SysLog("initializing token encoders")
@@ -31,19 +30,18 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
} }
for model, _ := range common.GetDefaultModelRatioMap() { for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") { tokenEncoderMap[model] = gpt4TokenEncoder
tokenEncoderMap[model] = cl200kTokenEncoder
} else {
tokenEncoderMap[model] = gpt4TokenEncoder
}
} else { } else {
tokenEncoderMap[model] = nil tokenEncoderMap[model] = nil
} }
@@ -51,13 +49,6 @@ func InitTokenEncoders() {
common.SysLog("token encoders initialized") common.SysLog("token encoders initialized")
} }
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
return cl200kTokenEncoder
}
return defaultTokenEncoder
}
func getTokenEncoder(model string) *tiktoken.Tiktoken { func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model] tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil { if ok && tokenEncoder != nil {
@@ -68,13 +59,12 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, err := tiktoken.EncodingForModel(model) tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = getModelDefaultTokenEncoder(model) tokenEncoder = defaultTokenEncoder
} }
tokenEncoderMap[model] = tokenEncoder tokenEncoderMap[model] = tokenEncoder
return tokenEncoder return tokenEncoder
} }
// 如果model不在tokenEncoderMap中直接返回默认的tokenEncoder return defaultTokenEncoder
return getModelDefaultTokenEncoder(model)
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
@@ -111,10 +101,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
var err error var err error
var format string var format string
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
config, format, err = DecodeUrlImageData(imageUrl.Url) common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, format, _, err = DecodeBase64ImageData(imageUrl.Url) config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err

View File

@@ -1,26 +0,0 @@
package service
import (
"bytes"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
if constant.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
workerUrl := constant.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
return http.Get(originUrl)
}
}

Some files were not shown because too many files have changed in this diff Show More