Compare commits

..

74 Commits

Author SHA1 Message Date
1808837298@qq.com
8b67664995 feat: 上游渠道为OpenAI渠道类型时,透传请求 (close #532) 2024-10-15 18:37:44 +08:00
1808837298@qq.com
ade6d0f56a fix: 修复Playground分组无用户分组 (close #529) 2024-10-14 16:22:38 +08:00
1808837298@qq.com
f599c65944 fix: 修复用户可选分组不能选择用户分组 (close #528) 2024-10-14 16:22:22 +08:00
1808837298@qq.com
40baa636e4 fix: 修复自定义聊天bug
(cherry picked from commit 8d41c17ccf19cb29100dbe506d3d42a6be822ff9)
2024-10-13 00:21:52 +08:00
1808837298@qq.com
d6359ec4ff feat: 完善自定义聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
89ddf83b44 feat: 弃用旧的聊天配置 2024-10-12 21:09:59 +08:00
1808837298@qq.com
6a8a4bcf65 fix: playground group 2024-10-10 13:39:09 +08:00
1808837298@qq.com
e298f2e5a4 feat: playground token name 2024-10-10 13:34:29 +08:00
1808837298@qq.com
8cea6dff4a feat: support embedding encoding_format param 2024-10-10 13:23:12 +08:00
1808837298@qq.com
5035cd054a feat: update aws claude 2024-10-09 00:42:36 +08:00
1808837298@qq.com
02c0c6501e feat: update auto disable 2024-10-08 23:15:57 +08:00
1808837298@qq.com
f0b808a41d feat: update model ratio 2024-10-03 21:12:09 +08:00
1808837298@qq.com
31d84ee32f feat: update model ratio 2024-10-03 20:48:47 +08:00
1808837298@qq.com
9969ed2d7c feat: update model ratio 2024-10-03 20:47:54 +08:00
1808837298@qq.com
746311242b fix: playground气泡溢出 #511 2024-09-27 20:49:26 +08:00
1808837298@qq.com
04a68a85dd feat: 优化playground样式 2024-09-27 20:49:25 +08:00
1808837298@qq.com
f9ba10f180 fix: playground max_tokens #512 #511 2024-09-27 20:18:53 +08:00
Calcium-Ion
334a6f8280 Update README.md 2024-09-26 01:54:33 +08:00
1808837298@qq.com
0cf53ac5ff feat: Playground相关接口禁用AccessToken 2024-09-26 01:49:35 +08:00
Calcium-Ion
af02cdc58b Merge pull request #509 from Calcium-Ion/playground
feat: playground
2024-09-26 01:00:33 +08:00
1808837298@qq.com
9a4ca1e210 feat: playground 2024-09-26 00:59:09 +08:00
1808837298@qq.com
9fe1f35fd1 fix: 第三方登录注销 #500 2024-09-25 17:15:59 +08:00
1808837298@qq.com
972ac1ee0f fix: 第三方登录注销 #500 2024-09-25 17:13:28 +08:00
1808837298@qq.com
0f95502b04 feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
b58b1dc0ec feat: 更新令牌生成算法 2024-09-25 16:31:25 +08:00
1808837298@qq.com
05d9aa61df feat: 不自动生成系统访问令牌 2024-09-25 16:31:25 +08:00
1808837298@qq.com
221894d972 fix: error user role 2024-09-24 17:49:57 +08:00
1808837298@qq.com
50eab6b4e4 chore: 更新令牌分组描述 2024-09-22 19:43:06 +08:00
1808837298@qq.com
ed972eef06 feat: pricing page support multi groups #487 2024-09-22 17:44:57 +08:00
CalciumIon
c6ff785a83 feat: 无可选分组时关闭令牌分组功能 #485 2024-09-19 03:01:33 +08:00
CalciumIon
2e734e0c37 chore: 令牌分组描述歧义 2024-09-19 02:52:25 +08:00
CalciumIon
af33f36c7b feat: update gemini flash completion ratio #479 2024-09-18 20:39:06 +08:00
CalciumIon
3aa86a8cd9 feat: update gemini completion ratio #479 2024-09-18 20:37:22 +08:00
CalciumIon
af7fecbfa7 fix: 使用令牌分组时 "/v1/models" 返回模型不正确 #481 2024-09-18 19:19:37 +08:00
CalciumIon
3fbdd502b6 fix: token group #477 2024-09-18 18:55:11 +08:00
CalciumIon
052bc2075b feat: 令牌分组 2024-09-18 05:19:49 +08:00
Calcium-Ion
5f3798053f Create FUNDING.yml 2024-09-18 01:41:31 +08:00
CalciumIon
e31022c676 Update logo 2024-09-18 01:25:00 +08:00
Calcium-Ion
fff7609f06 Merge pull request #439 from guoruqiang/main
改进了聊天页面,增加了初始令牌,方便用户注册后即可使用聊天功能。
2024-09-17 23:14:19 +08:00
CalciumIon
9032b5cfbf fix: 初始令牌 2024-09-17 23:07:16 +08:00
CalciumIon
131453dac8 Update README.md 2024-09-17 23:01:34 +08:00
CalciumIon
ed948c121a Merge branch 'main' into g-main
# Conflicts:
#	web/src/App.js
2024-09-17 22:50:59 +08:00
CalciumIon
a03cd15505 fix: '/v1/models' #474 2024-09-17 22:41:54 +08:00
CalciumIon
02f5137781 fix: '/v1/models' #474 2024-09-17 22:39:58 +08:00
CalciumIon
e6df0ed20c fix: '/vi/models' #474 2024-09-17 22:36:20 +08:00
CalciumIon
f505afdc10 feat: 添加令牌ip白名单功能 2024-09-17 20:49:51 +08:00
CalciumIon
feb1d76942 feat: 优化界面显示 2024-09-17 19:55:18 +08:00
CalciumIon
6263616cd9 Update README.md 2024-09-17 03:18:12 +08:00
GuoRuqiang
6bbf1d4843 Merge branch 'Calcium-Ion:main' into main 2024-09-14 19:00:03 +08:00
1808837298@qq.com
13c993d87e feat: format o1 model max tokens param 2024-09-14 16:11:38 +08:00
CalciumIon
cb73889353 feat: support o1 channel test 2024-09-13 03:17:04 +08:00
CalciumIon
804aad3f37 feat: support o1 channel test 2024-09-13 03:15:32 +08:00
CalciumIon
3af62a3efa feat: support OpenAI o1-preview and o1-mini 2024-09-13 01:22:27 +08:00
CalciumIon
be54369c12 chore: update footer 2024-09-12 18:43:01 +08:00
CalciumIon
0cbf8e07e7 feat: support ollama multi-text embedding 2024-09-12 18:29:45 +08:00
Calcium-Ion
1675679be9 Merge pull request #464 from Yan-Zero/main
fix: tool use in claude and add gemini mapping
2024-09-12 05:04:19 +08:00
Yan
0b5f2a7089 add gemini exp 2024-09-11 19:37:03 +08:00
Yan Tau
b5bb708072 Merge branch 'Calcium-Ion:main' into main 2024-09-11 19:29:50 +08:00
CalciumIon
2650ec9b59 feat: claude response return model name 2024-09-11 19:12:55 +08:00
CalciumIon
d168a685c1 fix: cohere SafetyMode 2024-09-11 19:12:32 +08:00
GuoRuqiang
a0d20896b3 Merge branch 'Calcium-Ion:main' into main 2024-09-08 15:56:54 +08:00
Yan
0ada2371b6 fix: tool use in claude 2024-09-05 00:53:00 +08:00
GuoRuqiang
a0673ef2b6 Merge branch 'Calcium-Ion:main' into main 2024-09-02 21:53:54 +08:00
GuoRuqiang
2223aeb022 Merge branch 'Calcium-Ion:main' into main 2024-08-29 19:42:03 +08:00
GuoRuqiang
ecf2f7f212 Merge branch 'Calcium-Ion:main' into main 2024-08-28 21:44:54 +08:00
GuoRuqiang
033359e93c Merge branch 'Calcium-Ion:main' into main 2024-08-28 10:44:14 +08:00
GuoRuqiang
1379d7f184 Merge pull request #2 from j471782517/main
增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。
2024-08-25 02:53:47 +08:00
Jin Weihan
716bf6f48a 增加环境变量GENERATE_DEFAULT_TOKEN 设置之后将生成初始令牌,默认关闭。 2024-08-24 18:44:37 +00:00
GuoRuqiang
2422eb2820 Merge branch 'Calcium-Ion:main' into main 2024-08-25 01:55:23 +08:00
GuoRuqiang
c97e2875b4 增加注册自动生成初始令牌。 2024-08-18 15:12:59 +00:00
GuoRuqiang
64794630c8 修改提示时间。 2024-08-17 16:59:31 +00:00
GuoRuqiang
fc5055c766 update App.js 2024-08-17 16:20:41 +00:00
GuoRuqiang
27eb358497 重新修改了chat 2024-08-17 16:17:24 +00:00
GuoRuqiang
6810ee0a28 Update Chat
修改chat界面,配合nextChat等前端可以自动传入第一个已启用令牌,
2024-08-17 23:09:45 +08:00
127 changed files with 6873 additions and 8163 deletions

12
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false blank_issues_enabled: false
contact_links: contact_links:
- name: 交流社区 - name: 项目群聊
url: https://linux.do url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
about: 项目交流社区 about: QQ 群629454374

View File

@@ -1,6 +1,9 @@
name: Publish Docker image (amd64) name: Publish Docker image (amd64)
on: on:
push:
tags:
- '*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
@@ -39,7 +42,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
pengzhile/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
pengzhile/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -1,6 +1,13 @@
<div align="center">
![new-api](/web/public/logo.png)
# New API # New API
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
> [!NOTE] > [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
@@ -44,7 +51,7 @@
## 模型支持 ## 模型支持
此版本额外支持以下模型: 此版本额外支持以下模型:
1. 第三方模型 **gps** gpt-4-gizmo-*, g-* 1. 第三方模型 **gps** gpt-4-gizmo-*
2. 智谱glm-4vglm-4v识图 2. 智谱glm-4vglm-4v识图
3. Anthropic Claude 3 3. Anthropic Claude 3
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改 4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
@@ -56,9 +63,10 @@
10. Dify 10. Dify
11. Vertex AI目前兼容ClaudeGeminiLlama3.1 11. Vertex AI目前兼容ClaudeGeminiLlama3.1
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。 您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 比原版One API多出的配置 ## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。 - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
@@ -115,24 +123,19 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
## Suno接口设置文档 ## Suno接口设置文档
[对接文档](Suno.md) [对接文档](Suno.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图 ## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式 夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="200">
## 相关项目 ## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目 - [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持 - [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持

View File

@@ -9,20 +9,9 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// Pay Settings
var StripeApiSecret = ""
var StripeWebhookSecret = ""
var StripePriceId = ""
var PaymentEnabled = false
var StripeUnitPrice = 8.0
var MinTopUp = 5
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API" var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var OutProxyUrl = ""
var Footer = "" var Footer = ""
var Logo = "" var Logo = ""
var TopUpLink = "" var TopUpLink = ""
@@ -52,12 +41,10 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制 var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制 var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
@@ -88,10 +75,6 @@ var SMTPToken = ""
var GitHubClientId = "" var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@@ -143,6 +126,10 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
func IsValidateRole(role int) bool {
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
}
var ( var (
FileUploadPermission = RoleGuestUser FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser FileDownloadPermission = RoleGuestUser
@@ -196,12 +183,6 @@ const (
ChannelStatusAutoDisabled = 3 ChannelStatusAutoDisabled = 3
) )
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1 ChannelTypeOpenAI = 1

View File

@@ -1,84 +0,0 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@@ -2,6 +2,7 @@ package common
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@@ -99,10 +100,12 @@ func LogQuota(quota int) string {
} }
} }
func LogQuotaF(quota float64) string { // LogJson 仅供测试使用 only for test
if DisplayInCurrencyEnabled { func LogJson(ctx context.Context, msg string, obj any) {
return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit) jsonStr, err := json.Marshal(obj)
} else { if err != nil {
return fmt.Sprintf("%d 点额度", int64(quota)) LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
} }
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
} }

View File

@@ -23,42 +23,48 @@ const (
var defaultModelRatio = map[string]float64{ var defaultModelRatio = map[string]float64{
//"midjourney": 50, //"midjourney": 50,
"gpt-4-gizmo-*": 15, "gpt-4-gizmo-*": 15,
"g-*": 15, "gpt-4o-gizmo-*": 2.5,
"gpt-4": 15, "gpt-4-all": 15,
"gpt-4-0314": 15, "gpt-4o-all": 15,
"gpt-4-0613": 15, "gpt-4": 15,
"gpt-4-32k": 30, //"gpt-4-0314": 15, //deprecated
"gpt-4-32k-0314": 30, "gpt-4-0613": 15,
"gpt-4-32k-0613": 30, "gpt-4-32k": 30,
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens //"gpt-4-32k-0314": 30, //deprecated
"gpt-4o-mini-2024-07-18": 0.075, "gpt-4-32k-0613": 30,
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens "gpt-4o": 2.5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens "o1-preview": 7.5,
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens "o1-preview-2024-09-12": 7.5,
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens "o1-mini": 1.5,
"gpt-3.5-turbo-0301": 0.75, "o1-mini-2024-09-12": 1.5,
"gpt-3.5-turbo-0613": 0.75, "gpt-4o-mini": 0.075,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-4o-mini-2024-07-18": 0.075,
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, //"gpt-3.5-turbo-0301": 0.75, //deprecated
"babbage-002": 0.2, // $0.0004 / 1K tokens "gpt-3.5-turbo-0613": 0.75,
"davinci-002": 1, // $0.002 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"text-ada-001": 0.2, "gpt-3.5-turbo-16k-0613": 1.5,
"text-babbage-001": 0.25, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"text-curie-001": 1, "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"text-davinci-002": 10, "gpt-3.5-turbo-0125": 0.25,
"text-davinci-003": 10, "babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
//"text-davinci-002": 10,
//"text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
@@ -80,9 +86,9 @@ var defaultModelRatio = map[string]float64{
"claude-2.0": 4, // $8 / 1M tokens "claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens "claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens "claude-3-5-sonnet-20240620": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB, "ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB, "ERNIE-3.5-8K-0205": 0.024 * RMB,
@@ -104,8 +110,10 @@ var defaultModelRatio = map[string]float64{
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1, "gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1, "gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1, "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
"gemini-1.5-flash-latest": 1, "gemini-1.5-flash-latest": 1,
"gemini-1.5-flash-exp-0827": 1,
"gemini-1.0-pro-latest": 1, "gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1, "gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1, "gemini-ultra": 1,
@@ -174,10 +182,8 @@ var defaultModelRatio = map[string]float64{
var defaultModelPrice = map[string]float64{ var defaultModelPrice = map[string]float64{
"suno_music": 0.1, "suno_music": 0.1,
"suno_lyrics": 0.01, "suno_lyrics": 0.01,
"dall-e-2": 0.02,
"dall-e-3": 0.04, "dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1, "gpt-4-gizmo-*": 0.1,
"g-*": 0.1,
"mj_imagine": 0.1, "mj_imagine": 0.1,
"mj_variation": 0.1, "mj_variation": 0.1,
"mj_reroll": 0.1, "mj_reroll": 0.1,
@@ -207,10 +213,9 @@ var (
var CompletionRatio map[string]float64 = nil var CompletionRatio map[string]float64 = nil
var defaultCompletionRatio = map[string]float64{ var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2, "gpt-4-gizmo-*": 2,
"g-*": 2, "gpt-4o-gizmo-*": 3,
"gpt-4-all": 2, "gpt-4-all": 2,
"gpt-4o-all": 2,
} }
func GetModelPriceMap() map[string]float64 { func GetModelPriceMap() map[string]float64 {
@@ -243,8 +248,9 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap() GetModelPriceMap()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") { }
name = "g-*" if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
} }
price, ok := modelPriceMap[name] price, ok := modelPriceMap[name]
if !ok { if !ok {
@@ -285,8 +291,6 @@ func GetModelRatio(name string) float64 {
GetModelRatioMap() GetModelRatioMap()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") {
name = "g-*"
} }
ratio, ok := modelRatioMap[name] ratio, ok := modelRatioMap[name]
if !ok { if !ok {
@@ -327,47 +331,51 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} else if strings.HasPrefix(name, "g-") { }
name = "g-*" if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
return 3
}
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
}
return 4
}
return 2
}
if strings.HasPrefix(name, "o1-") {
return 4
}
if name == "chatgpt-4o-latest" {
return 4
}
if strings.Contains(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
return 5
} }
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3 return 3
} }
if strings.HasSuffix(name, "1106") { if strings.HasSuffix(name, "1106") {
return 2 return 2
} }
if name == "gpt-3.5-turbo" {
return 3
}
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
if strings.HasPrefix(name, "gpt-4o-mini") || "gpt-4o-2024-08-06" == name {
return 4
}
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
return 3
}
return 2
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3
} else if strings.HasPrefix(name, "claude-2") {
return 3
} else if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "mistral-") { if strings.HasPrefix(name, "mistral-") {
return 3 return 3
} }
if strings.HasPrefix(name, "gemini-") { if strings.HasPrefix(name, "gemini-") {
return 3 return 4
} }
if strings.HasPrefix(name, "command") { if strings.HasPrefix(name, "command") {
switch name { switch name {

46
common/user_groups.go Normal file
View File

@@ -0,0 +1,46 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@@ -1,17 +1,15 @@
package common package common
import ( import (
"context" crand "crypto/rand"
"errors" "encoding/base64"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/net/proxy"
"html/template" "html/template"
"log" "log"
"math/big"
"math/rand" "math/rand"
"net" "net"
"net/http"
"net/url"
"os/exec" "os/exec"
"runtime" "runtime"
"strconv" "strconv"
@@ -133,6 +131,11 @@ func IntMax(a int, b int) int {
} }
} }
func IsIP(s string) bool {
ip := net.ParseIP(s)
return ip != nil
}
func GetUUID() string { func GetUUID() string {
code := uuid.New().String() code := uuid.New().String()
code = strings.Replace(code, "-", "", -1) code = strings.Replace(code, "-", "", -1)
@@ -142,24 +145,35 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() { func init() {
rand.Seed(time.Now().UnixNano()) rand.New(rand.NewSource(time.Now().UnixNano()))
} }
func GenerateKey() string { func GenerateRandomCharsKey(length int) (string, error) {
//rand.Seed(time.Now().UnixNano()) b := make([]byte, length)
key := make([]byte, 48) maxI := big.NewInt(int64(len(keyChars)))
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))] for i := range b {
} n, err := crand.Int(crand.Reader, maxI)
uuid_ := GetUUID() if err != nil {
for i := 0; i < 32; i++ { return "", err
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
} }
key[i+16] = c b[i] = keyChars[n.Int64()]
} }
return string(key)
return string(b), nil
}
func GenerateRandomKey(length int) (string, error) {
bytes := make([]byte, length*3/4) // 对于48位的输出这里应该是36
if _, err := crand.Read(bytes); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func GenerateKey() (string, error) {
//rand.Seed(time.Now().UnixNano())
return GenerateRandomCharsKey(48)
} }
func GetRandomInt(max int) int { func GetRandomInt(max int) int {
@@ -192,56 +206,3 @@ func RandomSleep() {
// Sleep for 0-3000 ms // Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
} }
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
if "" == proxyUrl {
return &http.Client{}, nil
}
u, err := url.Parse(proxyUrl)
if err != nil {
return nil, err
}
if strings.HasPrefix(proxyUrl, "http") {
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(u),
},
}, nil
} else if strings.HasPrefix(proxyUrl, "socks") {
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
return nil, err
}
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
},
},
}, nil
}
return nil, errors.New("unsupported proxy type")
}
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Get(url)
}
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
client, err := GetProxiedHttpClient(proxyUrl)
if err != nil {
return nil, err
}
return client.Head(url)
}

35
constant/chat.go Normal file
View File

@@ -0,0 +1,35 @@
package constant
import (
"encoding/json"
"one-api/common"
)
var Chats = []map[string]string{
{
"ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
},
{
"AMA 问天": "ama://set-api-key?server={address}&key={key}",
},
{
"OpenCat": "opencat://team/join?domain={address}&token={key}",
},
}
func UpdateChatsByJsonString(jsonString string) error {
Chats = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &Chats)
}
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
common.SysError("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
}

View File

@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{ var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta", "gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta", "gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta", "gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta", "gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta", "gemini-1.5-pro-exp-0827": "v1beta",
"gemini-1.5-flash-001": "v1beta", "gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash": "v1beta", "gemini-1.5-flash-exp-0827": "v1beta",
"gemini-ultra": "v1beta", "gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
} }
func InitEnv() { func InitEnv() {
@@ -44,3 +46,6 @@ func InitEnv() {
} }
} }
} }
// 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

View File

@@ -0,0 +1,9 @@
package constant
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""
var WorkerValidKey = ""
func EnableWorker() bool {
return WorkerUrl != ""
}

View File

@@ -20,6 +20,7 @@ import (
"one-api/relay/constant" "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
request := buildTestRequest() request := buildTestRequest(testModel)
request.Model = testModel
meta.UpstreamModelName = testModel meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return nil, nil return nil, nil
} }
func buildTestRequest() *dto.GeneralOpenAIRequest { func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest := &dto.GeneralOpenAIRequest{ testRequest := &dto.GeneralOpenAIRequest{
Model: "", // this will be set later Model: "", // this will be set later
MaxTokens: 1, Stream: false,
Stream: false, }
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
} else {
testRequest.MaxTokens = 1
} }
content, _ := json.Marshal("hi") content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
Role: "user", Role: "user",
Content: content, Content: content,
} }
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
ban := false shouldBanChannel := false
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
// request error disables the channel // request error disables the channel
if openaiWithStatusErr != nil { if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
} }
// parse *int to bool if milliseconds > disableThreshold {
if !channel.GetAutoBan() { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = false shouldBanChannel = true
} }
// disable channel // disable channel
if ban && isChannelEnabled { if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error()) service.DisableChannel(channel.Id, channel.Name, err.Error())
} }

View File

@@ -112,7 +112,9 @@ func GitHubOAuth(c *gin.Context) {
user := model.User{ user := model.User{
GitHubId: githubUser.Login, GitHubId: githubUser.Login,
} }
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) { if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId() err := user.FillUserByGitHubId()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -121,10 +123,16 @@ func GitHubOAuth(c *gin.Context) {
}) })
return return
} }
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@@ -135,7 +143,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
@@ -17,3 +18,22 @@ func GetGroups(c *gin.Context) {
"data": groupNames, "data": groupNames,
}) })
} }
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": usableGroups,
})
}

View File

@@ -1,239 +0,0 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -192,7 +192,7 @@ func DeleteHistoryLogs(c *gin.Context) {
}) })
return return
} }
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100) count, err := model.DeleteOldLog(targetTimestamp)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -233,7 +233,7 @@ func GetAllMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }
@@ -265,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
} }
if constant.MjForwardUrlEnabled { if constant.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }

View File

@@ -38,8 +38,6 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled, "telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName, "telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName, "system_name": common.SystemName,
@@ -47,9 +45,9 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": constant.ServerAddress,
"stripe_unit_price": common.StripeUnitPrice, "price": constant.Price,
"min_topup": common.MinTopUp, "min_topup": constant.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": common.TopUpLink,
@@ -63,8 +61,9 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"payment_enabled": common.PaymentEnabled, "enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
}, },
}) })
return return
@@ -206,7 +205,7 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@@ -137,31 +137,63 @@ func init() {
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]dto.OpenAIModels, 0) userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission() permission := getPermission()
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok { modelLimitEnable := c.GetBool("token_model_limit_enabled")
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else { } else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ tokenModelLimit = map[string]bool{}
Id: s, }
Object: "model", for allowModel, _ := range tokenModelLimit {
Created: 1626777600, if _, ok := openAIModelsMap[allowModel]; ok {
OwnedBy: "custom", userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
Permission: permission, } else {
Root: s, userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Parent: nil, Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "get user group failed",
}) })
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -50,14 +50,6 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -7,18 +7,11 @@ import (
) )
func GetPricing(c *gin.Context) { func GetPricing(c *gin.Context) {
userId := c.GetInt("id") pricing := model.GetPricing()
// if no login, get default group ratio
groupRatio := common.GetGroupRatio("default")
group, err := model.CacheGetUserGroup(userId)
if err == nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(group)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"data": pricing, "data": pricing,
"group_ratio": groupRatio, "group_ratio": common.GroupRatio,
}) })
} }

View File

@@ -38,6 +38,58 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
return err return err
} }
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)

View File

@@ -1,97 +0,0 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"io" "io"
"net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"sort" "sort"
@@ -48,6 +49,13 @@ func TelegramBind(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
user.TelegramId = telegramId user.TelegramId = telegramId
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -123,10 +123,19 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: common.GenerateKey(), Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@@ -134,6 +143,8 @@ func AddToken(c *gin.Context) {
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled, ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits, ModelLimits: token.ModelLimits,
AllowIps: token.AllowIps,
Group: token.Group,
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
@@ -221,6 +232,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits cleanToken.ModelLimits = token.ModelLimits
cleanToken.AllowIps = token.AllowIps
cleanToken.Group = token.Group
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@@ -1,20 +1,22 @@
package controller package controller
import "C"
import ( import (
"fmt" "fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76" "github.com/samber/lo"
"github.com/stripe/stripe-go/v76/checkout/session"
"log" "log"
"net/url"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/service"
"strconv" "strconv"
"strings" "sync"
"time" "time"
) )
type PayRequest struct { type EpayRequest struct {
Amount int `json:"amount"` Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"` PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
@@ -25,114 +27,201 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"` TopUpCode string `json:"top_up_code"`
} }
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) { func GetEpayClient() *epay.Client {
if !strings.HasPrefix(common.StripeApiSecret, "sk_") { if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
return "", fmt.Errorf("无效的Stripe API密钥") return nil
} }
withUrl, err := epay.NewClient(&epay.Config{
stripe.Key = common.StripeApiSecret PartnerID: constant.EpayId,
Key: constant.EpayKey,
params := &stripe.CheckoutSessionParams{ }, constant.PayAddress)
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil { if err != nil {
return "", err return nil
} }
return withUrl
return result.URL, nil
} }
func GetPayAmount(count float64) float64 { func getPayMoney(amount float64, group string) float64 {
return count * common.StripeUnitPrice if !common.DisplayInCurrencyEnabled {
} amount = amount / common.QuotaPerUnit
func GetChargedAmount(count float64, user model.User) float64 {
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if topUpGroupRatio == 0 {
topUpGroupRatio = 1
} }
// 别问为什么用float64问就是这么点钱没必要
return count * topUpGroupRatio topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * constant.Price * topupGroupRatio
return payMoney
} }
func RequestPayLink(c *gin.Context) { func getMinTopup() int {
var req PayRequest minTopup := constant.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
}
return minTopup
}
func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if !common.PaymentEnabled { if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"}) c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) group, err := model.CacheGetUserGroup(id)
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
referenceId := "ref_" + common.Sha1(reference)
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
if err != nil { if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err) c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) return
}
payMoney := getPayMoney(float64(req.Amount), group)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
var payType epay.PurchaseType
if req.PaymentMethod == "zfb" {
payType = epay.Alipay
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
}
topUp := &model.TopUp{ topUp := &model.TopUp{
UserId: id, UserId: id,
Amount: req.Amount, Amount: amount,
Money: chargedMoney, Money: payMoney,
TradeNo: referenceId, TradeNo: tradeNo,
CreateTime: time.Now().Unix(), CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending, Status: "pending",
} }
err = topUp.Insert() err = topUp.Insert()
if err != nil { if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"}) c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
"message": "success", }
"data": gin.H{
"payLink": payLink, // tradeNo lock
}, var orderLocks sync.Map
}) var createLock sync.Mutex
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
} }
func RequestAmount(c *gin.Context) { func RequestAmount(c *gin.Context) {
@@ -142,23 +231,21 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return return
} }
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"}) if req.Amount < getMinTopup() {
return c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) group, err := model.CacheGetUserGroup(id)
payMoney := GetPayAmount(float64(req.Amount)) if err != nil {
chargedMoney := GetChargedAmount(float64(req.Amount), *user) c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
c.JSON(200, gin.H{ return
"message": "success", }
"data": gin.H{ payMoney := getPayMoney(float64(req.Amount), group)
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64), if payMoney <= 0.01 {
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64), c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
}, return
}) }
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
} }

View File

@@ -7,10 +7,12 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/constant"
) )
type LoginRequest struct { type LoginRequest struct {
@@ -66,7 +68,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username) session.Set("username", user.Username)
session.Set("role", user.Role) session.Set("role", user.Role)
session.Set("status", user.Status) session.Set("status", user.Status)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel) session.Set("group", user.Group)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -158,8 +160,9 @@ func Register(c *gin.Context) {
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": "数据库错误,请稍后重试",
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@@ -187,6 +190,48 @@ func Register(c *gin.Context) {
}) })
return return
} }
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysError("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
token := model.Token{
UserId: insertedUser.Id, // 使用插入后的用户ID
Name: cleanUser.Username + "的初始令牌",
Key: key,
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: -1, // 永不过期
RemainQuota: 500000, // 示例额度
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -277,7 +322,18 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = common.GetUUID() // get rand int 28-32
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysError("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -518,7 +574,7 @@ func UpdateSelf(c *gin.Context) {
return return
} }
func HardDeleteUser(c *gin.Context) { func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -527,7 +583,7 @@ func HardDeleteUser(c *gin.Context) {
}) })
return return
} }
originUser, err := model.GetUserByIdUnscoped(id, false) originUser, err := model.GetUserById(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -551,23 +607,9 @@ func HardDeleteUser(c *gin.Context) {
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
} }
func DeleteSelf(c *gin.Context) { func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
@@ -597,6 +639,7 @@ func DeleteSelf(c *gin.Context) {
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -644,8 +687,8 @@ func CreateUser(c *gin.Context) {
} }
type ManageRequest struct { type ManageRequest struct {
Username string `json:"username"` Id int `json:"id"`
Action string `json:"action"` Action string `json:"action"`
} }
// ManageUser Only admin user can do this // ManageUser Only admin user can do this
@@ -661,7 +704,7 @@ func ManageUser(c *gin.Context) {
return return
} }
user := model.User{ user := model.User{
Username: req.Username, Id: req.Id,
} }
// Fill attributes // Fill attributes
model.DB.Unscoped().Where(&user).First(&user) model.DB.Unscoped().Where(&user).First(&user)

View File

@@ -78,6 +78,13 @@ func WeChatAuth(c *gin.Context) {
}) })
return return
} }
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)

View File

@@ -2,17 +2,18 @@ version: '3.4'
services: services:
new-api: new-api:
image: pengzhile/new-api:latest image: calciumion/new-api:latest
# build: .
container_name: new-api container_name: new-api
restart: always restart: always
command: --log-dir /app/logs command: --log-dir /app/logs
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data/new-api:/data - ./data:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
@@ -22,22 +23,13 @@ services:
depends_on: depends_on:
- redis - redis
- db healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
redis: redis:
image: redis:latest image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
db:
image: mysql:8.2.0
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

6
dto/playground.go Normal file
View File

@@ -0,0 +1,6 @@
package dto
type PlayGroundRequest struct {
Model string `json:"model,omitempty"`
Group string `json:"group,omitempty"`
}

View File

@@ -2,38 +2,38 @@ package dto
import "encoding/json" import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
BestOf int `json:"best_of,omitempty"` Stream bool `json:"stream,omitempty"`
Echo bool `json:"echo,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Stream bool `json:"stream,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Suffix string `json:"suffix,omitempty"` Temperature float64 `json:"temperature,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` TopP float64 `json:"top_p,omitempty"`
Temperature float64 `json:"temperature,omitempty"` TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"` Stop any `json:"stop,omitempty"`
TopK int `json:"top_k,omitempty"` N int `json:"n,omitempty"`
Stop any `json:"stop,omitempty"` Input any `json:"input,omitempty"`
N int `json:"n,omitempty"` Instruction string `json:"instruction,omitempty"`
Input any `json:"input,omitempty"` Size string `json:"size,omitempty"`
Instruction string `json:"instruction,omitempty"` Functions any `json:"functions,omitempty"`
Size string `json:"size,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
Functions any `json:"functions,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` ResponseFormat any `json:"response_format,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` EncodingFormat any `json:"encoding_format,omitempty"`
ResponseFormat any `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"`
Seed float64 `json:"seed,omitempty"` Tools []ToolCall `json:"tools,omitempty"`
Tools []ToolCall `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"`
User string `json:"user,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
LogitBias any `json:"logit_bias,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"`
LogProbs any `json:"logprobs,omitempty"` Dimensions int `json:"dimensions,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ParallelToolCalls bool `json:"parallel_Tool_Calls,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
} }
type OpenAITools struct { type OpenAITools struct {

View File

@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct { type OpenAITextResponse struct {
Id string `json:"id"` Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
@@ -41,9 +42,9 @@ type OpenAITextResponse struct {
} }
type OpenAIEmbeddingResponseItem struct { type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"` Object string `json:"object"`
Index int `json:"index"` Index int `json:"index"`
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
type OpenAIEmbeddingResponse struct { type OpenAIEmbeddingResponse struct {

5
go.mod
View File

@@ -6,6 +6,7 @@ go 1.21
toolchain go1.22.4 toolchain go1.22.4
require ( require (
github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1 github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/credentials v1.17.11
@@ -26,10 +27,8 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0 github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stripe/stripe-go/v76 v76.21.0
golang.org/x/crypto v0.26.0 golang.org/x/crypto v0.26.0
golang.org/x/image v0.15.0 golang.org/x/image v0.15.0
golang.org/x/net v0.28.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2 gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlite v1.4.3
@@ -69,6 +68,7 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
@@ -80,6 +80,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sync v0.8.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.24.0 // indirect golang.org/x/sys v0.24.0 // indirect
golang.org/x/text v0.17.0 // indirect golang.org/x/text v0.17.0 // indirect

8
go.sum
View File

@@ -1,3 +1,5 @@
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
@@ -134,6 +136,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -176,8 +180,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -203,7 +205,6 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -212,7 +213,6 @@ golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend: build-frontend:
@echo "Building frontend..." @echo "Building frontend..."
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build @cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
start-backend: start-backend:
@echo "Starting backend dev server..." @echo "Starting backend dev server..."

View File

@@ -10,13 +10,23 @@ import (
"strings" "strings"
) )
func validUserInfo(username string, role int) bool {
// check username is empty
if strings.TrimSpace(username) == "" {
return false
}
if !common.IsValidateRole(role) {
return false
}
return true
}
func authHelper(c *gin.Context, minRole int) { func authHelper(c *gin.Context, minRole int) {
session := sessions.Default(c) session := sessions.Default(c)
username := session.Get("username") username := session.Get("username")
role := session.Get("role") role := session.Get("role")
id := session.Get("id") id := session.Get("id")
status := session.Get("status") status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
useAccessToken := false useAccessToken := false
if username == nil { if username == nil {
// Check access token // Check access token
@@ -31,12 +41,19 @@ func authHelper(c *gin.Context, minRole int) {
} }
user := model.ValidateAccessToken(accessToken) user := model.ValidateAccessToken(accessToken)
if user != nil && user.Username != "" { if user != nil && user.Username != "" {
if !validUserInfo(user.Username, user.Role) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
// Token is valid // Token is valid
username = user.Username username = user.Username
role = user.Role role = user.Role
id = user.Id id = user.Id
status = user.Status status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
useAccessToken = true useAccessToken = true
} else { } else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -85,14 +102,6 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole { if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -101,9 +110,19 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort() c.Abort()
return return
} }
if !validUserInfo(username.(string), role.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权进行此操作,用户信息无效",
})
c.Abort()
return
}
c.Set("username", username) c.Set("username", username)
c.Set("role", role) c.Set("role", role)
c.Set("id", id) c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
c.Next() c.Next()
} }
@@ -172,15 +191,6 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁") abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return return
} }
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
@@ -194,6 +204,8 @@ func TokenAuth() func(c *gin.Context) {
} else { } else {
c.Set("token_model_limit_enabled", false) c.Set("token_model_limit_enabled", false)
} }
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("specific_channel_id", parts[1]) c.Set("specific_channel_id", parts[1])

View File

@@ -22,6 +22,14 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
@@ -31,6 +39,20 @@ func Distribute() func(c *gin.Context) {
return return
} }
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
c.Set("group", userGroup) c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))

View File

@@ -36,6 +36,12 @@ func GetEnabledModels() []string {
return models return models
} }
func GetAllEnableAbilities() []Ability {
var abilities []Ability
DB.Find(&abilities, "enabled = ?", true)
return abilities
}
func getPriority(group string, model string, retry int) (int, error) { func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`" groupCol := "`group`"
trueVal := "1" trueVal := "1"

View File

@@ -205,30 +205,6 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err return userEnabled, err
} }
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
@@ -293,8 +269,9 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") { if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*" model = "gpt-4-gizmo-*"
} else if strings.HasPrefix(model, "g-") { }
model = "g-*" if strings.HasPrefix(model, "gpt-4o-gizmo") {
model = "gpt-4o-gizmo-*"
} }
// if memory cache is disabled, get channel directly from database // if memory cache is disabled, get channel directly from database

View File

@@ -247,25 +247,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
return token return token
} }
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) { func DeleteOldLog(targetTimestamp int64) (int64, error) {
var total int64 = 0 result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
for {
if nil != ctx.Err() {
return total, ctx.Err()
}
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
if nil != result.Error {
return total, result.Error
}
total += result.RowsAffected
if result.RowsAffected < int64(limit) {
break
}
}
return total, nil
} }

View File

@@ -32,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: nil,
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)

View File

@@ -31,12 +31,10 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled) common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
@@ -62,19 +60,18 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["OutProxyUrl"] = "" common.OptionMap["WorkerUrl"] = constant.WorkerUrl
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret common.OptionMap["PayAddress"] = ""
common.OptionMap["StripePriceId"] = common.StripePriceId common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled) common.OptionMap["EpayId"] = ""
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64) common.OptionMap["EpayKey"] = ""
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp) common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = "" common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerAddress"] = ""
@@ -90,6 +87,7 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
@@ -177,8 +175,6 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled": case "TelegramOAuthEnabled":
@@ -187,8 +183,6 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled": case "EmailAliasRestrictionEnabled":
@@ -248,33 +242,31 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
common.ServerAddress = value constant.ServerAddress = value
case "OutProxyUrl": case "WorkerUrl":
common.OutProxyUrl = value constant.WorkerUrl = value
case "StripeApiSecret": case "WorkerValidKey":
common.StripeApiSecret = value constant.WorkerValidKey = value
case "StripeWebhookSecret": case "PayAddress":
common.StripeWebhookSecret = value constant.PayAddress = value
case "StripePriceId": case "Chats":
common.StripePriceId = value err = constant.UpdateChatsByJsonString(value)
case "PaymentEnabled": case "CustomCallbackAddress":
common.PaymentEnabled, _ = strconv.ParseBool(value) constant.CustomCallbackAddress = value
case "StripeUnitPrice": case "EpayId":
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64) constant.EpayId = value
case "EpayKey":
constant.EpayKey = value
case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
common.MinTopUp, _ = strconv.Atoi(value) constant.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value) err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId": case "GitHubClientId":
common.GitHubClientId = value common.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer": case "Footer":
common.Footer = value common.Footer = value
case "SystemName": case "SystemName":
@@ -315,6 +307,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value) err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice": case "ModelPrice":

View File

@@ -7,14 +7,13 @@ import (
) )
type Pricing struct { type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"` QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"` ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"` ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"` OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"` CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"` EnableGroup []string `json:"enable_groups,omitempty"`
} }
var ( var (
@@ -23,40 +22,47 @@ var (
updatePricingLock sync.Mutex updatePricingLock sync.Mutex
) )
func GetPricing(group string) []Pricing { func GetPricing() []Pricing {
updatePricingLock.Lock() updatePricingLock.Lock()
defer updatePricingLock.Unlock() defer updatePricingLock.Unlock()
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing() updatePricing()
} }
if group != "" { //if group != "" {
userPricingMap := make([]Pricing, 0) // userPricingMap := make([]Pricing, 0)
models := GetGroupModels(group) // models := GetGroupModels(group)
for _, pricing := range pricingMap { // for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) { // if !common.StringsContains(models, pricing.ModelName) {
pricing.Available = false // pricing.Available = false
} // }
userPricingMap = append(userPricingMap, pricing) // userPricingMap = append(userPricingMap, pricing)
} // }
return userPricingMap // return userPricingMap
} //}
return pricingMap return pricingMap
} }
func updatePricing() { func updatePricing() {
//modelRatios := common.GetModelRatios() //modelRatios := common.GetModelRatios()
enabledModels := GetEnabledModels() enableAbilities := GetAllEnableAbilities()
allModels := make(map[string]int) modelGroupsMap := make(map[string][]string)
for i, model := range enabledModels { for _, ability := range enableAbilities {
allModels[model] = i groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
modelGroupsMap[ability.Model] = groups
} }
pricingMap = make([]Pricing, 0) pricingMap = make([]Pricing, 0)
for model, _ := range allModels { for model, groups := range modelGroupsMap {
pricing := Pricing{ pricing := Pricing{
Available: true, ModelName: model,
ModelName: model, EnableGroup: groups,
} }
modelPrice, findPrice := common.GetModelPrice(model, false) modelPrice, findPrice := common.GetModelPrice(model, false)
if findPrice { if findPrice {

View File

@@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"strconv" "strconv"
"strings" "strings"
) )
@@ -22,10 +24,34 @@ type Token struct {
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"` ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"` ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
ipLimitsMap := make(map[string]any)
if token.AllowIps == nil {
return ipLimitsMap
}
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
if cleanIps == "" {
return ipLimitsMap
}
ips := strings.Split(cleanIps, "\n")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
ip = strings.ReplaceAll(ip, ",", "")
if common.IsIP(ip) {
ipLimitsMap[ip] = true
}
}
return ipLimitsMap
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
var tokens []*Token var tokens []*Token
var err error var err error
@@ -129,7 +155,8 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err return err
} }
@@ -231,51 +258,56 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err return err
} }
func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) { func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
if quota < 0 { if quota < 0 {
return 0, errors.New("quota 不能为负数!") return 0, errors.New("quota 不能为负数!")
} }
token, err := GetTokenById(tokenId) if !relayInfo.IsPlayground {
if err != nil { token, err := GetTokenById(relayInfo.TokenId)
return 0, err if err != nil {
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
} }
if !token.UnlimitedQuota && token.RemainQuota < quota { userQuota, err = GetUserQuota(relayInfo.UserId)
return 0, errors.New("令牌额度不足")
}
userQuota, err = GetUserQuota(token.UserId)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if userQuota < quota { if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
err = DecreaseTokenQuota(tokenId, quota) if !relayInfo.IsPlayground {
if err != nil { err = DecreaseTokenQuota(relayInfo.TokenId, quota)
return 0, err if err != nil {
return 0, err
}
} }
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(relayInfo.UserId, quota)
return userQuota - quota, err return userQuota - quota, err
} }
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(relayInfo.UserId, quota)
} else { } else {
err = IncreaseUserQuota(token.UserId, -quota) err = IncreaseUserQuota(relayInfo.UserId, -quota)
} }
if err != nil { if err != nil {
return err return err
} }
if quota > 0 { if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(tokenId, quota) if quota > 0 {
} else { err = DecreaseTokenQuota(relayInfo.TokenId, quota)
err = IncreaseTokenQuota(tokenId, -quota) } else {
} err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
if err != nil { }
return err if err != nil {
return err
}
} }
if sendEmail { if sendEmail {
@@ -284,7 +316,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(token.UserId) email, err := GetUserEmail(relayInfo.UserId)
if err != nil { if err != nil {
common.SysError("failed to fetch user email: " + err.Error()) common.SysError("failed to fetch user email: " + err.Error())
} }
@@ -293,7 +325,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@@ -1,21 +1,13 @@
package model package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
type TopUp struct { type TopUp struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"` Amount int `json:"amount"`
Money float64 `json:"money"` Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique"` TradeNo string `json:"trade_no"`
CreateTime int64 `json:"create_time"` CreateTime int64 `json:"create_time"`
CompleteTime int64 `json:"complete_time"` Status string `json:"status"`
Status string `json:"status"`
} }
func (topUp *TopUp) Insert() error { func (topUp *TopUp) Insert() error {
@@ -49,51 +41,3 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
} }
return topUp return topUp
} }
func Recharge(referenceId string, customerId string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota float64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
quota = topUp.Money * common.QuotaPerUnit
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.LogQuotaF(quota), topUp.Amount))
return nil
}

View File

@@ -22,12 +22,10 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
@@ -37,10 +35,20 @@ type User struct {
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度 AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度 AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
} }
func (user *User) GetAccessToken() string {
if user.AccessToken == nil {
return ""
}
return *user.AccessToken
}
func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) { func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User var user User
@@ -67,7 +75,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int { func GetMaxUserId() int {
var user User var user User
DB.Unscoped().Last(&user) DB.Last(&user)
return user.Id return user.Id
} }
@@ -122,20 +130,6 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err return &user, err
} }
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
if id == 0 {
return nil, errors.New("id 为空!")
}
user := User{Id: id}
var err error = nil
if selectAll {
err = DB.Unscoped().First(&user, "id = ?", id).Error
} else {
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
}
return &user, err
}
func GetUserIdByAffCode(affCode string) (int, error) { func GetUserIdByAffCode(affCode string) (int, error) {
if affCode == "" { if affCode == "" {
return 0, errors.New("affCode 为空!") return 0, errors.New("affCode 为空!")
@@ -218,7 +212,7 @@ func (user *User) Insert(inviterId int) error {
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID() //user.SetAccessToken(common.GetUUID())
user.AffCode = common.GetRandomString(4) user.AffCode = common.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
@@ -312,11 +306,12 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values, // that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions // it wont be used to build query conditions
password := user.Password password := user.Password
if user.Username == "" || password == "" { username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
// find buy username or email // find buy username or email
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user) DB.Where("username = ? OR email = ?", username, username).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@@ -348,14 +343,6 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -364,14 +351,6 @@ func (user *User) FillUserByWeChatId() error {
return nil return nil
} }
func (user *User) FillUserByUsername() error {
if user.Username == "" {
return errors.New("username 为空!")
}
DB.Where(User{Username: user.Username}).First(user)
return nil
}
func (user *User) FillUserByTelegramId() error { func (user *User) FillUserByTelegramId() error {
if user.TelegramId == "" { if user.TelegramId == "" {
return errors.New("Telegram id 为空!") return errors.New("Telegram id 为空!")
@@ -384,27 +363,19 @@ func (user *User) FillUserByTelegramId() error {
} }
func IsEmailAlreadyTaken(email string) bool { func IsEmailAlreadyTaken(email string) bool {
return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
} }
func IsWeChatIdAlreadyTaken(wechatId string) bool { func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
} }
func IsGitHubIdAlreadyTaken(githubId string) bool { func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }
func IsTelegramIdAlreadyTaken(telegramId string) bool { func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1 return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
} }
func ResetUserPasswordByEmail(email string, password string) error { func ResetUserPasswordByEmail(email string, password string) error {
@@ -444,18 +415,6 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled, nil
} }
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
if token == "" { if token == "" {
return nil return nil

View File

@@ -36,8 +36,8 @@ type AliEmbeddingRequest struct {
} }
type AliEmbedding struct { type AliEmbedding struct {
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"` TextIndex int `json:"text_index"`
} }
type AliEmbeddingResponse struct { type AliEmbeddingResponse struct {

View File

@@ -105,7 +105,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
for _, data := range response.Output.Results { for _, data := range response.Output.Results {
var b64Json string var b64Json string
if responseFormat == "b64_json" { if responseFormat == "b64_json" {
_, b64, err := common.GetImageFromUrl(data.Url) _, b64, err := service.GetImageFromUrl(data.Url)
if err != nil { if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error()) common.LogError(c, "get_image_data_failed: "+err.Error())
continue continue

View File

@@ -8,7 +8,6 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings"
) )
const ( const (
@@ -31,11 +30,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -53,11 +48,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
var claudeReq *claude.ClaudeRequest var claudeReq *claude.ClaudeRequest
var err error var err error
if a.RequestMode == RequestModeCompletion { claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
} else {
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
}
c.Set("request_model", request.Model) c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq) c.Set("converted_request", claudeReq)
return claudeReq, err return claudeReq, err

View File

@@ -50,9 +50,9 @@ type BaiduEmbeddingRequest struct {
} }
type BaiduEmbeddingData struct { type BaiduEmbeddingData struct {
Object string `json:"object"` Object string `json:"object"`
Embedding any `json:"embedding"` Embedding []float64 `json:"embedding"`
Index int `json:"index"` Index int `json:"index"`
} }
type BaiduEmbeddingResponse struct { type BaiduEmbeddingResponse struct {

View File

@@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
} }
formatMessages := make([]dto.Message, 0) formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message lastMessage := dto.Message{
Role: "tool",
}
for i, message := range textRequest.Messages { for i, message := range textRequest.Messages {
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" { if message.Role == "" {
textRequest.Messages[i].Role = "user" textRequest.Messages[i].Role = "user"
} }
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.Content,
} }
if lastMessage != nil && lastMessage.Role == message.Role { if message.Role == "tool" {
fmtMessage.ToolCallId = message.ToolCallId
}
if message.Role == "assistant" && message.ToolCalls != nil {
fmtMessage.ToolCalls = message.ToolCalls
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() { if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content fmtMessage.Content = content
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
fmtMessage.Content = content fmtMessage.Content = content
} }
formatMessages = append(formatMessages, fmtMessage) formatMessages = append(formatMessages, fmtMessage)
lastMessage = &textRequest.Messages[i] lastMessage = fmtMessage
} }
claudeMessages := make([]ClaudeMessage, 0) claudeMessages := make([]ClaudeMessage, 0)
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeMessage := ClaudeMessage{ claudeMessage := ClaudeMessage{
Role: message.Role, Role: message.Role,
} }
if message.IsStringContent() { if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
{
Type: "text",
Text: content,
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent() claudeMessage.Content = message.StringContent()
} else { } else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0) claudeMediaMessages := make([]ClaudeMediaMessage, 0)
@@ -193,11 +225,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url // 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url) mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data claudeMediaMessage.Source.Data = data
} else { } else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url) _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
} }
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
Input: inputObj,
})
}
}
claudeMessage.Content = claudeMediaMessages claudeMessage.Content = claudeMediaMessages
} }
claudeMessages = append(claudeMessages, claudeMessage) claudeMessages = append(claudeMessages, claudeMessage)
@@ -341,6 +395,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
if len(tools) > 0 { if len(tools) > 0 {
choice.Message.ToolCalls = tools choice.Message.ToolCalls = tools
} }
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice) choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse

View File

@@ -8,7 +8,7 @@ type CohereRequest struct {
Message string `json:"message"` Message string `json:"message"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
SafetyMode string `json:"safety_mode"` SafetyMode string `json:"safety_mode,omitempty"`
} }
type ChatHistory struct { type ChatHistory struct {

View File

@@ -22,7 +22,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
Message: "", Message: "",
Stream: textRequest.Stream, Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(), MaxTokens: textRequest.GetMaxTokens(),
SafetyMode: common.CohereSafetySetting, }
if common.CohereSafetySetting != "NONE" {
cohereReq.SafetyMode = common.CohereSafetySetting
} }
if cohereReq.MaxTokens == 0 { if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000 cohereReq.MaxTokens = 4000

View File

@@ -6,7 +6,7 @@ const (
var ModelList = []string{ var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra", "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
} }
var ChannelName = "google gemini" var ChannelName = "google gemini"

View File

@@ -86,7 +86,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
// 判断是否是url // 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") { if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据 // 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url) mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,
@@ -94,7 +94,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
}, },
}) })
} else { } else {
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url) _, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil { if err != nil {
continue continue
} }

View File

@@ -17,11 +17,25 @@ type OllamaRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
} }
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct { type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Prompt any `json:"prompt,omitempty"` Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
} }
type OllamaEmbeddingResponse struct { type OllamaEmbeddingResponse struct {
Embedding any `json:"embedding,omitempty"` Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding []float64 `json:"embedding,omitempty"`
} }

View File

@@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/service" "one-api/service"
"strings"
) )
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{ return &OllamaEmbeddingRequest{
Model: request.Model, Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "), Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if ollamaEmbeddingResponse.Error != "" {
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{ data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: ollamaEmbeddingResponse.Embedding, Embedding: ollamaEmbeddingResponse.Embedding,

View File

@@ -78,6 +78,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if info.ChannelType != common.ChannelTypeOpenAI { if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil request.StreamOptions = nil
} }
if strings.HasPrefix(request.Model, "o1-") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
}
return request, nil return request, nil
} }

View File

@@ -1,22 +1,24 @@
package openai package openai
var ModelList = []string{ var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct", "gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"chatgpt-4o-latest", "chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18", "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001",
"text-moderation-latest", "text-moderation-stable", "text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001", "text-davinci-edit-001",
"davinci-002", "babbage-002", "davinci-002", "babbage-002",
"dall-e-2", "dall-e-3", "dall-e-3",
"whisper-1", "whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
} }

View File

@@ -20,6 +20,7 @@ type RelayInfo struct {
setFirstResponse bool setFirstResponse bool
ApiType int ApiType int
IsStream bool IsStream bool
IsPlayground bool
RelayMode int RelayMode int
UpstreamModelName string UpstreamModelName string
OriginModelName string OriginModelName string
@@ -65,6 +66,11 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
} }
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" { if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType] info.BaseUrl = common.ChannelBaseURLs[channelType]
} }
@@ -146,3 +152,20 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
} }
return info return info
} }
func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
return &RelayInfo{
ChannelType: info.ChannelType,
ChannelId: info.ChannelId,
TokenId: info.TokenId,
UserId: info.UserId,
Group: info.Group,
StartTime: info.StartTime,
ApiType: info.ApiType,
RelayMode: info.RelayMode,
UpstreamModelName: info.UpstreamModelName,
RequestURLPath: info.RequestURLPath,
ApiKey: info.ApiKey,
BaseUrl: info.BaseUrl,
}
}

View File

@@ -42,7 +42,7 @@ const (
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
relayMode := RelayModeUnknown relayMode := RelayModeUnknown
if strings.HasPrefix(path, "/v1/chat/completions") { if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
relayMode = RelayModeChatCompletions relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(path, "/v1/completions") { } else if strings.HasPrefix(path, "/v1/completions") {
relayMode = RelayModeCompletions relayMode = RelayModeCompletions

View File

@@ -87,7 +87,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
@@ -126,7 +126,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -136,7 +136,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@@ -12,6 +12,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"strconv" "strconv"
@@ -30,7 +31,7 @@ func RelayMidjourneyImage(c *gin.Context) {
}) })
return return
} }
resp, err := common.ProxiedHttpGet(midjourneyTask.ImageUrl, common.OutProxyUrl) resp, err := http.Get(midjourneyTask.ImageUrl)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed", "error": "http_get_image_failed",
@@ -111,7 +112,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }
@@ -146,6 +147,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
var swapFaceRequest dto.SwapFaceRequest var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest) err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil { if err != nil {
@@ -191,7 +193,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@@ -356,6 +358,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
consumeQuota := true consumeQuota := true
var midjRequest dto.MidjourneyRequest var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest) err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -495,7 +498,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 { if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -76,6 +76,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
} }
// map model name // map model name
isModelMapped := false
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
//isModelMapped := false //isModelMapped := false
if modelMapping != "" && modelMapping != "{}" { if modelMapping != "" && modelMapping != "{}" {
@@ -85,6 +86,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
} }
if modelMap[textRequest.Model] != "" { if modelMap[textRequest.Model] != "" {
isModelMapped = true
textRequest.Model = modelMap[textRequest.Model] textRequest.Model = modelMap[textRequest.Model]
// set upstream model name // set upstream model name
//isModelMapped = true //isModelMapped = true
@@ -159,15 +161,23 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
adaptor.Init(relayInfo) adaptor.Init(relayInfo)
var requestBody io.Reader var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
if err != nil { body, err := common.GetRequestBody(c)
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
} }
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
@@ -178,7 +188,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
if resp != nil { if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -188,7 +198,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr
@@ -205,15 +215,6 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model) promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
case relayconstant.RelayModeCompletions: case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
prompts := textRequest.Prompt
switch v := prompts.(type) {
case string:
prompts = v + textRequest.Suffix
case []string:
prompts = append(v, textRequest.Suffix)
}
promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
case relayconstant.RelayModeModerations: case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
@@ -275,7 +276,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
} }
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
} }
@@ -283,11 +284,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return preConsumedQuota, userQuota, nil return preConsumedQuota, userQuota, nil
} }
func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) { func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func(ctx context.Context) { go func(ctx context.Context) {
// return pre-consumed quota // return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false) err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.SysError("error return pre-consumed quota: " + err.Error())
} }
@@ -345,7 +346,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//} //}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 { if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
@@ -362,8 +363,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
if strings.HasPrefix(logModel, "gpt-4-gizmo") { if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*" logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName) logContent += fmt.Sprintf(",模型 %s", modelName)
} else if strings.HasPrefix(logModel, "g-") { }
logModel = "g-*" if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
logModel = "gpt-4o-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", modelName) logContent += fmt.Sprintf(",模型 %s", modelName)
} }
if extraContent != "" { if extraContent != "" {

View File

@@ -101,7 +101,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
} }
if resp != nil { if resp != nil {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp) openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
@@ -111,7 +111,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr) service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr return openaiErr

View File

@@ -111,7 +111,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
defer func(ctx context.Context) { defer func(ctx context.Context) {
// release quota // release quota
if relayInfo.ConsumeQuota && taskErr == nil { if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -18,14 +18,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus) apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/about", controller.GetAbout)
apiRouter.GET("/midjourney", controller.GetMidjourney) //apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
@@ -33,26 +32,27 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
userRoute := apiRouter.Group("/user") userRoute := apiRouter.Group("/user")
{ {
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login) userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog) //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout) userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
userRoute.GET("/groups", controller.GetUserGroups)
selfRoute := userRoute.Group("/") selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth()) selfRoute.Use(middleware.UserAuth())
{ {
selfRoute.GET("/self/groups", controller.GetUserGroups)
selfRoute.GET("/self", controller.GetSelf) selfRoute.GET("/self", controller.GetSelf)
selfRoute.GET("/models", controller.GetUserModels) selfRoute.GET("/models", controller.GetUserModels)
selfRoute.PUT("/self", controller.UpdateSelf) selfRoute.PUT("/self", controller.UpdateSelf)
selfRoute.DELETE("/self", controller.DeleteSelf) selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode) selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp) selfRoute.POST("/topup", controller.TopUp)
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink) selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount) selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota) selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
} }
@@ -66,7 +66,7 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/", controller.CreateUser) adminRoute.POST("/", controller.CreateUser)
adminRoute.POST("/manage", controller.ManageUser) adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser) adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.HardDeleteUser) adminRoute.DELETE("/:id", controller.DeleteUser)
} }
} }
optionRoute := apiRouter.Group("/option") optionRoute := apiRouter.Group("/option")

View File

@@ -16,6 +16,11 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("", controller.ListModels) modelsRouter.GET("", controller.ListModels)
modelsRouter.GET("/:model", controller.RetrieveModel) modelsRouter.GET("/:model", controller.RetrieveModel)
} }
playgroundRouter := router.Group("/pg")
playgroundRouter.Use(middleware.UserAuth())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{ {

View File

@@ -73,6 +73,15 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
} else if strings.HasPrefix(err.Error.Message, "Permission denied") { } else if strings.HasPrefix(err.Error.Message, "Permission denied") {
return true return true
} }
if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
return true
} else if strings.Contains(err.Error.Message, "Operation not allowed") {
return true
} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
return true
}
return false return false
} }

12
service/epay.go Normal file
View File

@@ -0,0 +1,12 @@
package service
import (
"one-api/constant"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
return constant.ServerAddress
}
return constant.CustomCallbackAddress
}

View File

@@ -1,4 +1,4 @@
package common package service
import ( import (
"bytes" "bytes"
@@ -8,6 +8,7 @@ import (
"golang.org/x/image/webp" "golang.org/x/image/webp"
"image" "image"
"io" "io"
"one-api/common"
"strings" "strings"
) )
@@ -30,24 +31,9 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err return config, format, base64String, err
} }
func IsImageUrl(url string) (bool, error) {
resp, err := ProxiedHttpHead(url, OutProxyUrl)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据 // GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url) resp, err := DoImageRequest(url)
if !isImage {
return
}
resp, err := ProxiedHttpGet(url, OutProxyUrl)
if err != nil { if err != nil {
return return
} }
@@ -66,9 +52,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
} }
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := ProxiedHttpGet(imageUrl, OutProxyUrl) response, err := DoImageRequest(imageUrl)
if err != nil { if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err return image.Config{}, "", err
} }
defer response.Body.Close() defer response.Body.Close()
@@ -80,7 +66,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
var readData []byte var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
// 从response.Body读取更多的数据直到达到当前的限制 // 从response.Body读取更多的数据直到达到当前的限制
additionalData := make([]byte, limit-int64(len(readData))) additionalData := make([]byte, limit-int64(len(readData)))
@@ -106,11 +92,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
config, format, err := image.DecodeConfig(reader) config, format, err := image.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error()) common.SysLog(err.Error())
config, err = webp.DecodeConfig(reader) config, err = webp.DecodeConfig(reader)
if err != nil { if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error()) common.SysLog(err.Error())
} }
format = "webp" format = "webp"
} }

View File

@@ -18,6 +18,7 @@ import (
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
common.SysLog("initializing token encoders") common.SysLog("initializing token encoders")
@@ -30,20 +31,19 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
} }
for model, _ := range common.GetDefaultModelRatioMap() { for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "chatgpt-4o") {
tokenEncoderMap[model] = gpt4oTokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = cl200kTokenEncoder
} else {
tokenEncoderMap[model] = gpt4TokenEncoder
}
} else { } else {
tokenEncoderMap[model] = nil tokenEncoderMap[model] = nil
} }
@@ -51,6 +51,13 @@ func InitTokenEncoders() {
common.SysLog("token encoders initialized") common.SysLog("token encoders initialized")
} }
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
return cl200kTokenEncoder
}
return defaultTokenEncoder
}
func getTokenEncoder(model string) *tiktoken.Tiktoken { func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model] tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil { if ok && tokenEncoder != nil {
@@ -61,12 +68,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, err := tiktoken.EncodingForModel(model) tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder tokenEncoder = getModelDefaultTokenEncoder(model)
} }
tokenEncoderMap[model] = tokenEncoder tokenEncoderMap[model] = tokenEncoder
return tokenEncoder return tokenEncoder
} }
return defaultTokenEncoder // 如果model不在tokenEncoderMap中直接返回默认的tokenEncoder
return getModelDefaultTokenEncoder(model)
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
@@ -103,11 +111,10 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
var err error var err error
var format string var format string
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url)) config, format, err = DecodeUrlImageData(imageUrl.Url)
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url) config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err

26
service/worker.go Normal file
View File

@@ -0,0 +1,26 @@
package service
import (
"bytes"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
if constant.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
workerUrl := constant.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
return http.Get(originUrl)
}
}

5
web/.gitignore vendored
View File

@@ -10,7 +10,6 @@
# production # production
/build /build
/dist
# misc # misc
.DS_Store .DS_Store
@@ -22,4 +21,6 @@
npm-debug.log* npm-debug.log*
yarn-debug.log* yarn-debug.log*
yarn-error.log* yarn-error.log*
.idea/ .idea
package-lock.json
yarn.lock

View File

@@ -1 +1 @@
module.exports = require('@so1ve/prettier-config'); module.exports = require("@so1ve/prettier-config");

View File

@@ -4,8 +4,8 @@
"private": true, "private": true,
"type": "module", "type": "module",
"dependencies": { "dependencies": {
"@douyinfe/semi-icons": "^2.46.1", "@douyinfe/semi-icons": "^2.63.1",
"@douyinfe/semi-ui": "^2.55.3", "@douyinfe/semi-ui": "^2.63.1",
"@visactor/react-vchart": "~1.8.8", "@visactor/react-vchart": "~1.8.8",
"@visactor/vchart": "~1.8.8", "@visactor/vchart": "~1.8.8",
"@visactor/vchart-semi-theme": "~1.8.8", "@visactor/vchart-semi-theme": "~1.8.8",
@@ -22,7 +22,8 @@
"react-toastify": "^9.0.8", "react-toastify": "^9.0.8",
"react-turnstile": "^1.0.5", "react-turnstile": "^1.0.5",
"semantic-ui-offline": "^2.5.0", "semantic-ui-offline": "^2.5.0",
"semantic-ui-react": "^2.1.3" "semantic-ui-react": "^2.1.3",
"sse": "github:mpetazzoni/sse.js"
}, },
"scripts": { "scripts": {
"dev": "vite", "dev": "vite",

5282
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 KiB

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.9 KiB

After

Width:  |  Height:  |  Size: 7.6 KiB

View File

@@ -11,7 +11,6 @@ import EditUser from './pages/User/EditUser';
import { getLogo, getSystemName } from './helpers'; import { getLogo, getSystemName } from './helpers';
import PasswordResetForm from './components/PasswordResetForm'; import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth'; import GitHubOAuth from './components/GitHubOAuth';
import LinuxDoOAuth from './components/LinuxDoOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm'; import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User'; import { UserContext } from './context/User';
import Channel from './pages/Channel'; import Channel from './pages/Channel';
@@ -21,11 +20,12 @@ import Redemption from './pages/Redemption';
import TopUp from './pages/TopUp'; import TopUp from './pages/TopUp';
import Log from './pages/Log'; import Log from './pages/Log';
import Chat from './pages/Chat'; import Chat from './pages/Chat';
import Chat2Link from './pages/Chat2Link';
import { Layout } from '@douyinfe/semi-ui'; import { Layout } from '@douyinfe/semi-ui';
import Midjourney from './pages/Midjourney'; import Midjourney from './pages/Midjourney';
import Pricing from './pages/Pricing/index.js'; import Pricing from './pages/Pricing/index.js';
import Task from './pages/Task/index.js'; import Task from "./pages/Task/index.js";
// import Detail from './pages/Detail'; import Playground from './components/Playground.js';
const Home = lazy(() => import('./pages/Home')); const Home = lazy(() => import('./pages/Home'));
const Detail = lazy(() => import('./pages/Detail')); const Detail = lazy(() => import('./pages/Detail'));
@@ -59,215 +59,224 @@ function App() {
}, []); }, []);
return ( return (
<Layout> <>
<Layout.Content> <Routes>
<Routes> <Route
<Route path='/'
path='/' element={
element={ <Suspense fallback={<Loading></Loading>}>
<Home />
</Suspense>
}
/>
<Route
path='/channel'
element={
<PrivateRoute>
<Channel />
</PrivateRoute>
}
/>
<Route
path='/channel/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/channel/add'
element={
<Suspense fallback={<Loading></Loading>}>
<EditChannel />
</Suspense>
}
/>
<Route
path='/token'
element={
<PrivateRoute>
<Token />
</PrivateRoute>
}
/>
<Route
path='/playground'
element={
<PrivateRoute>
<Playground />
</PrivateRoute>
}
/>
<Route
path='/redemption'
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path='/user'
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path='/user/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/edit'
element={
<Suspense fallback={<Loading></Loading>}>
<EditUser />
</Suspense>
}
/>
<Route
path='/user/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
</Suspense>
}
/>
<Route
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<Home /> <Setting />
</Suspense> </Suspense>
} </PrivateRoute>
/> }
<Route />
path='/channel' <Route
element={ path='/topup'
<PrivateRoute> element={
<Channel /> <PrivateRoute>
</PrivateRoute>
}
/>
<Route
path='/channel/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<EditChannel /> <TopUp />
</Suspense> </Suspense>
} </PrivateRoute>
/> }
<Route />
path='/channel/add' <Route
element={ path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<EditChannel /> <Detail />
</Suspense> </Suspense>
} </PrivateRoute>
/> }
<Route />
path='/token' <Route
element={ path='/midjourney'
<PrivateRoute> element={
<Token /> <PrivateRoute>
</PrivateRoute>
}
/>
<Route
path='/redemption'
element={
<PrivateRoute>
<Redemption />
</PrivateRoute>
}
/>
<Route
path='/user'
element={
<PrivateRoute>
<User />
</PrivateRoute>
}
/>
<Route
path='/user/edit/:id'
element={
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<EditUser /> <Midjourney />
</Suspense> </Suspense>
} </PrivateRoute>
/> }
<Route />
path='/user/edit' <Route
element={ path='/task'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<EditUser /> <Task />
</Suspense> </Suspense>
} </PrivateRoute>
/> }
/>
<Route
path='/pricing'
element={
<Suspense fallback={<Loading></Loading>}>
<Pricing />
</Suspense>
}
/>
<Route
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path='/chat/:id?'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
{/* 方便使用chat2link直接跳转聊天... */}
<Route <Route
path='/user/reset' path='/chat2link'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetConfirm />
</Suspense>
}
/>
<Route
path='/login'
element={
<Suspense fallback={<Loading></Loading>}>
<LoginForm />
</Suspense>
}
/>
<Route
path='/register'
element={
<Suspense fallback={<Loading></Loading>}>
<RegisterForm />
</Suspense>
}
/>
<Route
path='/reset'
element={
<Suspense fallback={<Loading></Loading>}>
<PasswordResetForm />
</Suspense>
}
/>
<Route
path='/oauth/github'
element={
<Suspense fallback={<Loading></Loading>}>
<GitHubOAuth />
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
<Suspense fallback={<Loading></Loading>}>
<LinuxDoOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={ element={
<PrivateRoute> <PrivateRoute>
<Suspense fallback={<Loading></Loading>}> <Suspense fallback={<Loading></Loading>}>
<Setting /> <Chat2Link />
</Suspense> </Suspense>
</PrivateRoute> </PrivateRoute>
} }
/> />
<Route
path='/topup'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<TopUp />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/log'
element={
<PrivateRoute>
<Log />
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Detail />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Midjourney />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/task'
element={
<PrivateRoute>
<Suspense fallback={<Loading></Loading>}>
<Task />
</Suspense>
</PrivateRoute>
}
/>
<Route
path='/pricing'
element={
<Suspense fallback={<Loading></Loading>}>
<Pricing />
</Suspense>
}
/>
<Route
path='/about'
element={
<Suspense fallback={<Loading></Loading>}>
<About />
</Suspense>
}
/>
<Route
path='/chat'
element={
<Suspense fallback={<Loading></Loading>}>
<Chat />
</Suspense>
}
/>
<Route path='*' element={<NotFound />} /> <Route path='*' element={<NotFound />} />
</Routes> </Routes>
</Layout.Content> </>
</Layout>
); );
} }

View File

@@ -98,18 +98,14 @@ const ChannelsTable = () => {
render: (text, record, index) => { render: (text, record, index) => {
if (text === 3) { if (text === 3) {
if (record.other_info === '') { if (record.other_info === '') {
record.other_info = '{}'; record.other_info = '{}'
} }
let otherInfo = JSON.parse(record.other_info); let otherInfo = JSON.parse(record.other_info);
let reason = otherInfo['status_reason']; let reason = otherInfo['status_reason'];
let time = otherInfo['status_time']; let time = otherInfo['status_time'];
return ( return (
<div> <div>
<Tooltip <Tooltip content={'原因:' + reason + ',时间:' + timestamp2string(time)}>
content={
'原因:' + reason + ',时间:' + timestamp2string(time)
}
>
{renderStatus(text)} {renderStatus(text)}
</Tooltip> </Tooltip>
</div> </div>
@@ -749,7 +745,7 @@ const ChannelsTable = () => {
<Form.Select <Form.Select
field='group' field='group'
label='分组' label='分组'
optionList={[{ label: '选择分组', value: null }, ...groupOptions]} optionList={[{ label: '选择分组', value: null}, ...groupOptions]}
initValue={null} initValue={null}
onChange={(v) => { onChange={(v) => {
setSearchGroup(v); setSearchGroup(v);

View File

@@ -3,7 +3,7 @@ import React, { useEffect, useState } from 'react';
import { getFooterHTML, getSystemName } from '../helpers'; import { getFooterHTML, getSystemName } from '../helpers';
import { Layout, Tooltip } from '@douyinfe/semi-ui'; import { Layout, Tooltip } from '@douyinfe/semi-ui';
const Footer = () => { const FooterBar = () => {
const systemName = getSystemName(); const systemName = getSystemName();
const [footer, setFooter] = useState(getFooterHTML()); const [footer, setFooter] = useState(getFooterHTML());
let remainCheckTimes = 5; let remainCheckTimes = 5;
@@ -25,7 +25,11 @@ const Footer = () => {
New API {import.meta.env.VITE_REACT_APP_VERSION}{' '} New API {import.meta.env.VITE_REACT_APP_VERSION}{' '}
</a> </a>
{' '} {' '}
<a href='https://github.com/Calcium-Ion' target='_blank' rel='noreferrer'> <a
href='https://github.com/Calcium-Ion'
target='_blank'
rel='noreferrer'
>
Calcium-Ion Calcium-Ion
</a>{' '} </a>{' '}
开发基于{' '} 开发基于{' '}
@@ -52,21 +56,17 @@ const Footer = () => {
}, []); }, []);
return ( return (
<Layout> <div style={{ textAlign: 'center' }}>
<Layout.Content style={{ textAlign: 'center' }}> {footer ? (
{footer ? ( <div
<Tooltip content={defaultFooter}> className='custom-footer'
<div dangerouslySetInnerHTML={{ __html: footer }}
className='custom-footer' ></div>
dangerouslySetInnerHTML={{ __html: footer }} ) : (
></div> defaultFooter
</Tooltip> )}
) : ( </div>
defaultFooter
)}
</Layout.Content>
</Layout>
); );
}; };
export default Footer; export default FooterBar;

View File

@@ -1,8 +1,9 @@
import React, { useContext, useEffect, useState } from 'react'; import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react'; import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom'; import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers'; import { API, showError, showSuccess, updateAPI } from '../helpers';
import { UserContext } from '../context/User'; import { UserContext } from '../context/User';
import { setUserData } from '../helpers/data.js';
const GitHubOAuth = () => { const GitHubOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams(); const [searchParams, setSearchParams] = useSearchParams();
@@ -14,22 +15,19 @@ const GitHubOAuth = () => {
let navigate = useNavigate(); let navigate = useNavigate();
const sendCode = async (code, state, count) => { const sendCode = async (code, state, count) => {
let aff = localStorage.getItem('aff'); const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
const res = await API.get(
`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`,
);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
localStorage.removeItem('aff');
if (message === 'bind') { if (message === 'bind') {
showSuccess('绑定成功!'); showSuccess('绑定成功!');
navigate('/setting'); navigate('/setting');
} else { } else {
userDispatch({ type: 'login', payload: data }); userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data)); localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI()
showSuccess('登录成功!'); showSuccess('登录成功!');
navigate('/'); navigate('/token');
} }
} else { } else {
showError(message); showError(message);
@@ -46,14 +44,6 @@ const GitHubOAuth = () => {
}; };
useEffect(() => { useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code'); let code = searchParams.get('code');
let state = searchParams.get('state'); let state = searchParams.get('state');
sendCode(code, state, 0).then(); sendCode(code, state, 0).then();

View File

@@ -3,14 +3,23 @@ import { Link, useNavigate } from 'react-router-dom';
import { UserContext } from '../context/User'; import { UserContext } from '../context/User';
import { useSetTheme, useTheme } from '../context/Theme'; import { useSetTheme, useTheme } from '../context/Theme';
import { API, getLogo, getSystemName, showSuccess } from '../helpers'; import { API, getLogo, getSystemName, isMobile, showSuccess } from '../helpers';
import '../index.css'; import '../index.css';
import fireworks from 'react-fireworks'; import fireworks from 'react-fireworks';
import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons'; import {
IconHelpCircle,
IconHome,
IconHomeStroked,
IconKey,
IconNoteMoneyStroked,
IconPriceTag,
IconUser
} from '@douyinfe/semi-icons';
import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
import { stringToColor } from '../helpers/render'; import { stringToColor } from '../helpers/render';
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
// HeaderBar Buttons // HeaderBar Buttons
let headerButtons = [ let headerButtons = [
@@ -22,6 +31,21 @@ let headerButtons = [
}, },
]; ];
let buttons = [
{
text: '首页',
itemKey: 'home',
to: '/',
// icon: <IconHomeStroked />,
},
// {
// text: 'Playground',
// itemKey: 'playground',
// to: '/playground',
// // icon: <IconNoteMoneyStroked />,
// },
];
if (localStorage.getItem('chat_link')) { if (localStorage.getItem('chat_link')) {
headerButtons.splice(1, 0, { headerButtons.splice(1, 0, {
name: '聊天', name: '聊天',
@@ -90,6 +114,7 @@ const HeaderBar = () => {
about: '/about', about: '/about',
login: '/login', login: '/login',
register: '/register', register: '/register',
home: '/',
}; };
return ( return (
<Link <Link
@@ -103,6 +128,18 @@ const HeaderBar = () => {
selectedKeys={[]} selectedKeys={[]}
// items={headerButtons} // items={headerButtons}
onSelect={(key) => {}} onSelect={(key) => {}}
header={isMobile()?{
logo: (
<img src={logo} alt='logo' style={{ marginRight: '0.75em' }} />
),
}:{
logo: (
<img src={logo} alt='logo' />
),
text: systemName,
}}
items={buttons}
footer={ footer={
<> <>
{isNewYear && ( {isNewYear && (
@@ -121,15 +158,19 @@ const HeaderBar = () => {
</Dropdown> </Dropdown>
)} )}
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} /> <Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
<Switch <>
checkedText='🌞' {!isMobile() && (
size={'large'} <Switch
checked={theme === 'dark'} checkedText='🌞'
uncheckedText='🌙' size={'large'}
onChange={(checked) => { checked={theme === 'dark'}
setTheme(checked); uncheckedText='🌙'
}} onChange={(checked) => {
/> setTheme(checked);
}}
/>
)}
</>
{userState.user ? ( {userState.user ? (
<> <>
<Dropdown <Dropdown
@@ -155,7 +196,7 @@ const HeaderBar = () => {
<Nav.Item <Nav.Item
itemKey={'login'} itemKey={'login'}
text={'登录'} text={'登录'}
icon={<IconKey />} // icon={<IconKey />}
/> />
<Nav.Item <Nav.Item
itemKey={'register'} itemKey={'register'}

View File

@@ -1,27 +0,0 @@
import React from 'react';
import { Icon } from '@douyinfe/semi-ui';
const LinuxDoIcon = (props) => {
function CustomIcon() {
return (
<svg
className='icon'
viewBox='0 0 24 24'
version='1.1'
xmlns='http://www.w3.org/2000/svg'
width='1em'
height='1em'
{...props}
>
<path
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
fill='currentColor'
/>
</svg>
);
}
return <Icon svg={<CustomIcon />} />;
};
export default LinuxDoIcon;

View File

@@ -1,71 +0,0 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const LinuxDoOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
let aff = localStorage.getItem('aff');
const res = await API.get(
`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`,
);
const { success, message, data } = res.data;
if (success) {
localStorage.removeItem('aff');
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default LinuxDoOAuth;

View File

@@ -1,15 +1,8 @@
import React, { useContext, useEffect, useState } from 'react'; import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User'; import { UserContext } from '../context/User';
import { import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
API, import { onGitHubOAuthClicked } from './utils';
getLogo,
showError,
showInfo,
showSuccess,
updateAPI,
} from '../helpers';
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
import Turnstile from 'react-turnstile'; import Turnstile from 'react-turnstile';
import { import {
Button, Button,
@@ -25,7 +18,6 @@ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login'; import TelegramLoginButton from 'react-telegram-login';
import { IconGithubLogo } from '@douyinfe/semi-icons'; import { IconGithubLogo } from '@douyinfe/semi-icons';
import LinuxDoIcon from './LinuxDoIcon';
import WeChatIcon from './WeChatIcon'; import WeChatIcon from './WeChatIcon';
import { setUserData } from '../helpers/data.js'; import { setUserData } from '../helpers/data.js';
@@ -79,6 +71,8 @@ const LoginForm = () => {
if (success) { if (success) {
userDispatch({ type: 'login', payload: data }); userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data)); localStorage.setItem('user', JSON.stringify(data));
setUserData(data);
updateAPI()
navigate('/'); navigate('/');
showSuccess('登录成功!'); showSuccess('登录成功!');
setShowWeChatLoginModal(false); setShowWeChatLoginModal(false);
@@ -109,7 +103,7 @@ const LoginForm = () => {
if (success) { if (success) {
userDispatch({ type: 'login', payload: data }); userDispatch({ type: 'login', payload: data });
setUserData(data); setUserData(data);
updateAPI(); updateAPI()
showSuccess('登录成功!'); showSuccess('登录成功!');
if (username === 'root' && password === '123456') { if (username === 'root' && password === '123456') {
Modal.error({ Modal.error({
@@ -151,6 +145,8 @@ const LoginForm = () => {
userDispatch({ type: 'login', payload: data }); userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data)); localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!'); showSuccess('登录成功!');
setUserData(data);
updateAPI()
navigate('/'); navigate('/');
} else { } else {
showError(message); showError(message);
@@ -217,7 +213,6 @@ const LoginForm = () => {
</Text> </Text>
</div> </div>
{status.github_oauth || {status.github_oauth ||
status.linuxdo_oauth ||
status.wechat_login || status.wechat_login ||
status.telegram_oauth ? ( status.telegram_oauth ? (
<> <>
@@ -242,43 +237,35 @@ const LoginForm = () => {
) : ( ) : (
<></> <></>
)} )}
{status.linuxdo_oauth ? (
<Button
type='primary'
icon={<LinuxDoIcon />}
style={{ color: '#000', margin: '0 5px' }}
onClick={() =>
onLinuxDoOAuthClicked(status.linuxdo_client_id)
}
/>
) : (
<></>
)}
{status.wechat_login ? ( {status.wechat_login ? (
<Button <Button
type='primary' type='primary'
style={{ style={{ color: 'rgba(var(--semi-green-5), 1)' }}
color: 'rgba(var(--semi-green-5), 1)',
margin: '0 5px',
}}
icon={<Icon svg={<WeChatIcon />} />} icon={<Icon svg={<WeChatIcon />} />}
onClick={onWeChatLoginClicked} onClick={onWeChatLoginClicked}
/> />
) : ( ) : (
<></> <></>
)} )}
{status.telegram_oauth ? (
<TelegramLoginButton
className='semi-button semi-button-with-icon semi-button-with-icon-only'
buttonSize='medium'
dataOnauth={onTelegramLoginClicked}
botName={status.telegram_bot_name}
/>
) : (
<></>
)}
</div> </div>
{status.telegram_oauth ? (
<>
<div
style={{
display: 'flex',
justifyContent: 'center',
marginTop: 5,
}}
>
<TelegramLoginButton
dataOnauth={onTelegramLoginClicked}
botName={status.telegram_bot_name}
/>
</div>
</>
) : (
<></>
)}
</> </>
) : ( ) : (
<></> <></>

View File

@@ -250,7 +250,7 @@ const LogsTable = () => {
title: '类型', title: '类型',
dataIndex: 'type', dataIndex: 'type',
render: (text, record, index) => { render: (text, record, index) => {
return <div>{renderType(text)}</div>; return <>{renderType(text)}</>;
}, },
}, },
{ {
@@ -258,7 +258,7 @@ const LogsTable = () => {
dataIndex: 'model_name', dataIndex: 'model_name',
render: (text, record, index) => { render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? ( return record.type === 0 || record.type === 2 ? (
<div> <>
<Tag <Tag
color={stringToColor(text)} color={stringToColor(text)}
size='large' size='large'
@@ -269,7 +269,7 @@ const LogsTable = () => {
{' '} {' '}
{text}{' '} {text}{' '}
</Tag> </Tag>
</div> </>
) : ( ) : (
<></> <></>
); );
@@ -282,22 +282,22 @@ const LogsTable = () => {
if (record.is_stream) { if (record.is_stream) {
let other = getLogOther(record.other); let other = getLogOther(record.other);
return ( return (
<div> <>
<Space> <Space>
{renderUseTime(text)} {renderUseTime(text)}
{renderFirstUseTime(other.frt)} {renderFirstUseTime(other.frt)}
{renderIsStream(record.is_stream)} {renderIsStream(record.is_stream)}
</Space> </Space>
</div> </>
); );
} else { } else {
return ( return (
<div> <>
<Space> <Space>
{renderUseTime(text)} {renderUseTime(text)}
{renderIsStream(record.is_stream)} {renderIsStream(record.is_stream)}
</Space> </Space>
</div> </>
); );
} }
}, },
@@ -307,7 +307,7 @@ const LogsTable = () => {
dataIndex: 'prompt_tokens', dataIndex: 'prompt_tokens',
render: (text, record, index) => { render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? ( return record.type === 0 || record.type === 2 ? (
<div>{<span> {text} </span>}</div> <>{<span> {text} </span>}</>
) : ( ) : (
<></> <></>
); );
@@ -319,7 +319,7 @@ const LogsTable = () => {
render: (text, record, index) => { render: (text, record, index) => {
return parseInt(text) > 0 && return parseInt(text) > 0 &&
(record.type === 0 || record.type === 2) ? ( (record.type === 0 || record.type === 2) ? (
<div>{<span> {text} </span>}</div> <>{<span> {text} </span>}</>
) : ( ) : (
<></> <></>
); );
@@ -330,7 +330,7 @@ const LogsTable = () => {
dataIndex: 'quota', dataIndex: 'quota',
render: (text, record, index) => { render: (text, record, index) => {
return record.type === 0 || record.type === 2 ? ( return record.type === 0 || record.type === 2 ? (
<div>{renderQuota(text, 6)}</div> <>{renderQuota(text, 6)}</>
) : ( ) : (
<></> <></>
); );

View File

@@ -92,9 +92,9 @@ function renderType(type) {
); );
case 'UPLOAD': case 'UPLOAD':
return ( return (
<Tag color='blue' size='large'> <Tag color='blue' size='large'>
上传文件 上传文件
</Tag> </Tag>
); );
case 'SHORTEN': case 'SHORTEN':
return ( return (
@@ -262,7 +262,7 @@ function renderDuration(submit_time, finishTime) {
// 返回带有样式的颜色标签 // 返回带有样式的颜色标签
return ( return (
<Tag color={color} size='large'> <Tag color={color} size="large">
{durationSec} {durationSec}
</Tag> </Tag>
); );

View File

@@ -1,5 +1,5 @@
import React, { useContext, useEffect, useRef, useMemo, useState } from 'react'; import React, { useContext, useEffect, useRef, useMemo, useState } from 'react';
import { API, copy, showError, showSuccess } from '../helpers'; import { API, copy, showError, showInfo, showSuccess } from '../helpers';
import { import {
Banner, Banner,
@@ -46,33 +46,37 @@ function renderQuotaType(type) {
function renderAvailable(available) { function renderAvailable(available) {
return available ? ( return available ? (
<Popover <Popover
content={<div style={{ padding: 8 }}>您的分组可以使用该模型</div>} content={
position='top' <div style={{ padding: 8 }}>您的分组可以使用该模型</div>
key={available} }
style={{ position='top'
backgroundColor: 'rgba(var(--semi-blue-4),1)', key={available}
borderColor: 'rgba(var(--semi-blue-4),1)', style={{
color: 'var(--semi-color-white)', backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderWidth: 1, borderColor: 'rgba(var(--semi-blue-4),1)',
borderStyle: 'solid', color: 'var(--semi-color-white)',
}} borderWidth: 1,
borderStyle: 'solid',
}}
> >
<IconVerify style={{ color: 'green' }} size='large' /> <IconVerify style={{ color: 'green' }} size="large" />
</Popover> </Popover>
) : ( ) : (
<Popover <Popover
content={<div style={{ padding: 8 }}>您的分组无权使用该模型</div>} content={
position='top' <div style={{ padding: 8 }}>您的分组无权使用该模型</div>
key={available} }
style={{ position='top'
backgroundColor: 'rgba(var(--semi-blue-4),1)', key={available}
borderColor: 'rgba(var(--semi-blue-4),1)', style={{
color: 'var(--semi-color-white)', backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderWidth: 1, borderColor: 'rgba(var(--semi-blue-4),1)',
borderStyle: 'solid', color: 'var(--semi-color-white)',
}} borderWidth: 1,
borderStyle: 'solid',
}}
> >
<IconUploadError style={{ color: '#FFA54F' }} size='large' /> <IconUploadError style={{ color: '#FFA54F' }} size="large" />
</Popover> </Popover>
); );
} }
@@ -83,14 +87,15 @@ const ModelPricing = () => {
const [selectedRowKeys, setSelectedRowKeys] = useState([]); const [selectedRowKeys, setSelectedRowKeys] = useState([]);
const [modalImageUrl, setModalImageUrl] = useState(''); const [modalImageUrl, setModalImageUrl] = useState('');
const [isModalOpenurl, setIsModalOpenurl] = useState(false); const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [selectedGroup, setSelectedGroup] = useState('default');
const rowSelection = useMemo( const rowSelection = useMemo(
() => ({ () => ({
onChange: (selectedRowKeys, selectedRows) => { onChange: (selectedRowKeys, selectedRows) => {
setSelectedRowKeys(selectedRowKeys); setSelectedRowKeys(selectedRowKeys);
}, },
}), }),
[], []
); );
const handleChange = (value) => { const handleChange = (value) => {
@@ -116,7 +121,8 @@ const ModelPricing = () => {
title: '可用性', title: '可用性',
dataIndex: 'available', dataIndex: 'available',
render: (text, record, index) => { render: (text, record, index) => {
return renderAvailable(text); // if record.enable_groups contains selectedGroup, then available is true
return renderAvailable(record.enable_groups.includes(selectedGroup));
}, },
sorter: (a, b) => a.available - b.available, sorter: (a, b) => a.available - b.available,
}, },
@@ -162,25 +168,58 @@ const ModelPricing = () => {
}, },
sorter: (a, b) => a.quota_type - b.quota_type, sorter: (a, b) => a.quota_type - b.quota_type,
}, },
{
title: '可用分组',
dataIndex: 'enable_groups',
render: (text, record, index) => {
// enable_groups is a string array
return (
<Space>
{text.map((group) => {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo('当前查看的分组为:' + group + ',倍率为:' + groupRatio[group]);
}}
>
{group}
</Tag>
);
}
})}
</Space>
);
},
},
{ {
title: () => ( title: () => (
<span style={{ display: 'flex', alignItems: 'center' }}> <span style={{'display':'flex','alignItems':'center'}}>
倍率 倍率
<Popover <Popover
content={ content={
<div style={{ padding: 8 }}> <div style={{ padding: 8 }}>倍率是为了方便换算不同价格的模型<br/>点击查看倍率说明</div>
倍率是为了方便换算不同价格的模型
<br />
点击查看倍率说明
</div>
} }
position='top' position='top'
style={{ style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)', backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)', borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)', color: 'var(--semi-color-white)',
borderWidth: 1, borderWidth: 1,
borderStyle: 'solid', borderStyle: 'solid',
}} }}
> >
<IconHelpCircle <IconHelpCircle
@@ -200,9 +239,9 @@ const ModelPricing = () => {
<> <>
<Text>模型{record.quota_type === 0 ? text : '无'}</Text> <Text>模型{record.quota_type === 0 ? text : '无'}</Text>
<br /> <br />
<Text> <Text>补全{record.quota_type === 0 ? completionRatio : '无'}</Text>
补全{record.quota_type === 0 ? completionRatio : '无'} <br />
</Text> <Text>分组{groupRatio[selectedGroup]}</Text>
</> </>
); );
return <div>{content}</div>; return <div>{content}</div>;
@@ -215,12 +254,11 @@ const ModelPricing = () => {
let content = text; let content = text;
if (record.quota_type === 0) { if (record.quota_type === 0) {
// 这里的 *2 是因为 1倍率=0.002刀,请勿删除 // 这里的 *2 是因为 1倍率=0.002刀,请勿删除
let inputRatioPrice = record.model_ratio * 2 * record.group_ratio; let inputRatioPrice = record.model_ratio * 2 * groupRatio[selectedGroup];
let completionRatioPrice = let completionRatioPrice =
record.model_ratio * record.model_ratio *
record.completion_ratio * record.completion_ratio * 2 *
2 * groupRatio[selectedGroup];
record.group_ratio;
content = ( content = (
<> <>
<Text>提示 ${inputRatioPrice} / 1M tokens</Text> <Text>提示 ${inputRatioPrice} / 1M tokens</Text>
@@ -229,7 +267,7 @@ const ModelPricing = () => {
</> </>
); );
} else { } else {
let price = parseFloat(text) * record.group_ratio; let price = parseFloat(text) * groupRatio[selectedGroup];
content = <>模型价格${price}</>; content = <>模型价格${price}</>;
} }
return <div>{content}</div>; return <div>{content}</div>;
@@ -240,12 +278,12 @@ const ModelPricing = () => {
const [models, setModels] = useState([]); const [models, setModels] = useState([]);
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [userState, userDispatch] = useContext(UserContext); const [userState, userDispatch] = useContext(UserContext);
const [groupRatio, setGroupRatio] = useState(1); const [groupRatio, setGroupRatio] = useState({});
const setModelsFormat = (models, groupRatio) => { const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) { for (let i = 0; i < models.length; i++) {
models[i].key = models[i].model_name; models[i].key = models[i].model_name;
models[i].group_ratio = groupRatio; models[i].group_ratio = groupRatio[models[i].model_name];
} }
// sort by quota_type // sort by quota_type
models.sort((a, b) => { models.sort((a, b) => {
@@ -278,6 +316,7 @@ const ModelPricing = () => {
const { success, message, data, group_ratio } = res.data; const { success, message, data, group_ratio } = res.data;
if (success) { if (success) {
setGroupRatio(group_ratio); setGroupRatio(group_ratio);
setSelectedGroup(userState.user ? userState.user.group : 'default')
setModelsFormat(data, group_ratio); setModelsFormat(data, group_ratio);
} else { } else {
showError(message); showError(message);
@@ -307,40 +346,35 @@ const ModelPricing = () => {
<Layout> <Layout>
{userState.user ? ( {userState.user ? (
<Banner <Banner
type='success' type="success"
fullMode={false} fullMode={false}
closeIcon='null' closeIcon="null"
description={`您的分组为:${userState.user.group},分组倍率为:${groupRatio}`} description={`您的默认分组为:${userState.user.group},分组倍率为:${groupRatio[userState.user.group]}`}
/> />
) : ( ) : (
<Banner <Banner
type='warning' type='warning'
fullMode={false} fullMode={false}
closeIcon='null' closeIcon="null"
description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio}`} description={`您还未登陆,显示的价格为默认分组倍率: ${groupRatio['default']}`}
/> />
)} )}
<br /> <br/>
<Banner <Banner
type='info' type="info"
fullMode={false} fullMode={false}
description={ description={<div>按量计费费用 = 分组倍率 × 模型倍率 × 提示token数 + 补全token数 × 补全倍率/ 500000 单位美元</div>}
<div> closeIcon="null"
按量计费费用 = 分组倍率 × 模型倍率 × 提示token数 + 补全token数 ×
补全倍率/ 500000 单位美元
</div>
}
closeIcon='null'
/> />
<br /> <br/>
<Button <Button
theme='light' theme='light'
type='tertiary' type='tertiary'
style={{ width: 150 }} style={{width: 150}}
onClick={() => { onClick={() => {
copyText(selectedRowKeys); copyText(selectedRowKeys);
}} }}
disabled={selectedRowKeys == ''} disabled={selectedRowKeys == ""}
> >
复制选中模型 复制选中模型
</Button> </Button>

View File

@@ -10,6 +10,7 @@ import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.
import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js'; import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
import { API, showError, showSuccess } from '../helpers'; import { API, showError, showSuccess } from '../helpers';
import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
const OperationSetting = () => { const OperationSetting = () => {
let [inputs, setInputs] = useState({ let [inputs, setInputs] = useState({
@@ -23,6 +24,7 @@ const OperationSetting = () => {
CompletionRatio: '', CompletionRatio: '',
ModelPrice: '', ModelPrice: '',
GroupRatio: '', GroupRatio: '',
UserUsableGroups: '',
TopUpLink: '', TopUpLink: '',
ChatLink: '', ChatLink: '',
ChatLink2: '', // 添加的新状态变量 ChatLink2: '', // 添加的新状态变量
@@ -49,6 +51,7 @@ const OperationSetting = () => {
DataExportInterval: 5, DataExportInterval: 5,
DefaultCollapseSidebar: false, // 默认折叠侧边栏 DefaultCollapseSidebar: false, // 默认折叠侧边栏
RetryTimes: 0, RetryTimes: 0,
Chats: "[]",
}); });
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@@ -62,6 +65,7 @@ const OperationSetting = () => {
if ( if (
item.key === 'ModelRatio' || item.key === 'ModelRatio' ||
item.key === 'GroupRatio' || item.key === 'GroupRatio' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' || item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' item.key === 'ModelPrice'
) { ) {
@@ -129,6 +133,10 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}> <Card style={{ marginTop: '10px' }}>
<SettingsCreditLimit options={inputs} refresh={onRefresh} /> <SettingsCreditLimit options={inputs} refresh={onRefresh} />
</Card> </Card>
{/* 聊天设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsChats options={inputs} refresh={onRefresh} />
</Card>
{/* 倍率设置 */} {/* 倍率设置 */}
<Card style={{ marginTop: '10px' }}> <Card style={{ marginTop: '10px' }}>
<SettingsMagnification options={inputs} refresh={onRefresh} /> <SettingsMagnification options={inputs} refresh={onRefresh} />

View File

@@ -10,7 +10,7 @@ import {
} from '../helpers'; } from '../helpers';
import Turnstile from 'react-turnstile'; import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User'; import { UserContext } from '../context/User';
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils'; import { onGitHubOAuthClicked } from './utils';
import { import {
Avatar, Avatar,
Banner, Banner,
@@ -519,39 +519,6 @@ const PersonalSetting = () => {
</div> </div>
</div> </div>
</div> </div>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>LINUX DO</Typography.Text>
<div
style={{ display: 'flex', justifyContent: 'space-between' }}
>
<div>
<Input
value={
userState.user && userState.user.linuxdo_id !== ''
? userState.user.linuxdo_id +
'' +
userState.user.linuxdo_level +
'级)'
: '未绑定'
}
readonly={true}
></Input>
</div>
<div>
<Button
onClick={() => {
onLinuxDoOAuthClicked(status.linuxdo_client_id);
}}
disabled={
(userState.user && userState.user.linuxdo_id !== '') ||
!status.linuxdo_oauth
}
>
{status.linuxdo_oauth ? '绑定' : '未启用'}
</Button>
</div>
</div>
</div>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>Telegram</Typography.Text> <Typography.Text strong>Telegram</Typography.Text>

View File

@@ -0,0 +1,342 @@
import React, { useCallback, useContext, useEffect, useState } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getUserIdFromLocalStorage, showError } from '../helpers';
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography } from '@douyinfe/semi-ui';
import { SSE } from 'sse';
const defaultMessage = [
{
role: 'user',
id: '2',
createAt: 1715676751919,
content: "你好",
},
{
role: 'assistant',
id: '3',
createAt: 1715676751919,
content: "你好,请问有什么可以帮助您的吗?",
}
];
let id = 4;
function getId() {
return `${id++}`
}
const Playground = () => {
const [inputs, setInputs] = useState({
model: 'gpt-4o-mini',
group: '',
max_tokens: 0,
temperature: 0,
});
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [status, setStatus] = useState({});
const [systemPrompt, setSystemPrompt] = useState('You are a helpful assistant. You can help me by answering my questions. You can also ask me questions.');
const [message, setMessage] = useState(defaultMessage);
const [models, setModels] = useState([]);
const [groups, setGroups] = useState([]);
const handleInputChange = (name, value) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
useEffect(() => {
if (searchParams.get('expired')) {
showError('未登录或登录已过期,请重新登录!');
}
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
}
loadModels();
loadGroups();
}, []);
const loadModels = async () => {
let res = await API.get(`/api/user/models`);
const { success, message, data } = res.data;
if (success) {
let localModelOptions = data.map((model) => ({
label: model,
value: model,
}));
setModels(localModelOptions);
} else {
showError(message);
}
};
const loadGroups = async () => {
let res = await API.get(`/api/user/self/groups`);
const { success, message, data } = res.data;
if (success) {
// return data is a map, key is group name, value is group description
// label is group description, value is group name
let localGroupOptions = Object.keys(data).map((group) => ({
label: data[group],
value: group,
}));
// handleInputChange('group', localGroupOptions[0].value);
if (localGroupOptions.length > 0) {
} else {
localGroupOptions = [{
label: '用户分组',
value: '',
}];
setGroups(localGroupOptions);
}
setGroups(localGroupOptions);
handleInputChange('group', localGroupOptions[0].value);
} else {
showError(message);
}
};
const commonOuterStyle = {
border: '1px solid var(--semi-color-border)',
borderRadius: '16px',
margin: '0px 8px',
}
const getSystemMessage = () => {
if (systemPrompt !== '') {
return {
role: 'system',
id: '1',
createAt: 1715676751919,
content: systemPrompt,
}
}
}
let handleSSE = (payload) => {
let source = new SSE('/pg/chat/completions', {
headers: {
"Content-Type": "application/json",
"New-Api-User": getUserIdFromLocalStorage(),
},
method: "POST",
payload: JSON.stringify(payload),
});
source.addEventListener("message", (e) => {
if (e.data !== "[DONE]") {
let payload = JSON.parse(e.data);
// console.log("Payload: ", payload);
if (payload.choices.length === 0) {
source.close();
completeMessage();
} else {
let text = payload.choices[0].delta.content;
if (text) {
generateMockResponse(text);
}
}
} else {
completeMessage();
}
});
source.addEventListener("error", (e) => {
generateMockResponse(e.data)
completeMessage('error')
});
source.addEventListener("readystatechange", (e) => {
if (e.readyState >= 2) {
if (source.status === undefined) {
source.close();
completeMessage();
}
}
});
source.stream();
}
const onMessageSend = useCallback((content, attachment) => {
console.log("attachment: ", attachment);
setMessage((prevMessage) => {
const newMessage = [
...prevMessage,
{
role: 'user',
content: content,
createAt: Date.now(),
id: getId()
}
];
// 将 getPayload 移到这里
const getPayload = () => {
let systemMessage = getSystemMessage();
let messages = newMessage.map((item) => {
return {
role: item.role,
content: item.content,
}
});
if (systemMessage) {
messages.unshift(systemMessage);
}
return {
messages: messages,
stream: true,
model: inputs.model,
group: inputs.group,
max_tokens: parseInt(inputs.max_tokens),
temperature: inputs.temperature,
};
};
// 使用更新后的消息状态调用 handleSSE
handleSSE(getPayload());
newMessage.push({
role: 'assistant',
content: '',
createAt: Date.now(),
id: getId(),
status: 'loading'
});
return newMessage;
});
}, [getSystemMessage]);
const completeMessage = useCallback((status = 'complete') => {
// console.log("Complete Message: ", status)
setMessage((prevMessage) => {
const lastMessage = prevMessage[prevMessage.length - 1];
// only change the status if the last message is not complete and not error
if (lastMessage.status === 'complete' || lastMessage.status === 'error') {
return prevMessage;
}
return [
...prevMessage.slice(0, -1),
{ ...lastMessage, status: status }
];
});
}, [])
const generateMockResponse = useCallback((content) => {
// console.log("Generate Mock Response: ", content);
setMessage((message) => {
const lastMessage = message[message.length - 1];
let newMessage = {...lastMessage};
if (lastMessage.status === 'loading' || lastMessage.status === 'incomplete') {
newMessage = {
...newMessage,
content: (lastMessage.content || '') + content,
status: 'incomplete'
}
}
return [ ...message.slice(0, -1), newMessage ]
})
}, []);
return (
<Layout style={{height: '100%'}}>
<Layout.Sider>
<Card style={commonOuterStyle}>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>分组</Typography.Text>
</div>
<Select
placeholder={'请选择分组'}
name='group'
required
selection
onChange={(value) => {
handleInputChange('group', value);
}}
value={inputs.group}
autoComplete='new-password'
optionList={groups}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>模型</Typography.Text>
</div>
<Select
placeholder={'请选择模型'}
name='model'
required
selection
filter
onChange={(value) => {
handleInputChange('model', value);
}}
value={inputs.model}
autoComplete='new-password'
optionList={models}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>Temperature</Typography.Text>
</div>
<Slider
step={0.1}
min={0.1}
max={1}
value={inputs.temperature}
onChange={(value) => {
handleInputChange('temperature', value);
}}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>MaxTokens</Typography.Text>
</div>
<Input
placeholder='MaxTokens'
name='max_tokens'
required
autoComplete='new-password'
defaultValue={0}
value={inputs.max_tokens}
onChange={(value) => {
handleInputChange('max_tokens', value);
}}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>System</Typography.Text>
</div>
<TextArea
placeholder='System Prompt'
name='system'
required
autoComplete='new-password'
autosize
defaultValue={systemPrompt}
// value={systemPrompt}
onChange={(value) => {
setSystemPrompt(value);
}}
/>
</Card>
</Layout.Sider>
<Layout.Content>
<div style={{height: '100%'}}>
<Chat
chatBoxRenderConfig={{
renderChatBoxAction: () => {
return <div></div>
}
}}
style={commonOuterStyle}
chats={message}
onMessageSend={onMessageSend}
showClearContext
onClear={() => {
setMessage([]);
}}
/>
</div>
</Layout.Content>
</Layout>
);
};
export default Playground;

View File

@@ -12,7 +12,7 @@ const RegisterForm = () => {
password: '', password: '',
password2: '', password2: '',
email: '', email: '',
verification_code: '', verification_code: ''
}); });
const { username, password, password2 } = inputs; const { username, password, password2 } = inputs;
const [showEmailVerification, setShowEmailVerification] = useState(false); const [showEmailVerification, setShowEmailVerification] = useState(false);
@@ -65,12 +65,10 @@ const RegisterForm = () => {
inputs.aff_code = affCode; inputs.aff_code = affCode;
const res = await API.post( const res = await API.post(
`/api/user/register?turnstile=${turnstileToken}`, `/api/user/register?turnstile=${turnstileToken}`,
inputs, inputs
); );
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
localStorage.removeItem('aff');
navigate('/login'); navigate('/login');
showSuccess('注册成功!'); showSuccess('注册成功!');
} else { } else {
@@ -88,7 +86,7 @@ const RegisterForm = () => {
} }
setLoading(true); setLoading(true);
const res = await API.get( const res = await API.get(
`/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`, `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`
); );
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
@@ -108,7 +106,7 @@ const RegisterForm = () => {
style={{ style={{
justifyContent: 'center', justifyContent: 'center',
display: 'flex', display: 'flex',
marginTop: 120, marginTop: 120
}} }}
> >
<div style={{ width: 500 }}> <div style={{ width: 500 }}>
@@ -116,28 +114,28 @@ const RegisterForm = () => {
<Title heading={2} style={{ textAlign: 'center' }}> <Title heading={2} style={{ textAlign: 'center' }}>
新用户注册 新用户注册
</Title> </Title>
<Form size='large'> <Form size="large">
<Form.Input <Form.Input
field={'username'} field={'username'}
label={'用户名'} label={'用户名'}
placeholder='用户名' placeholder="用户名"
name='username' name="username"
onChange={(value) => handleChange('username', value)} onChange={(value) => handleChange('username', value)}
/> />
<Form.Input <Form.Input
field={'password'} field={'password'}
label={'密码'} label={'密码'}
placeholder='密码,最短 8 位,最长 20 位' placeholder="密码,最短 8 位,最长 20 位"
name='password' name="password"
type='password' type="password"
onChange={(value) => handleChange('password', value)} onChange={(value) => handleChange('password', value)}
/> />
<Form.Input <Form.Input
field={'password2'} field={'password2'}
label={'确认密码'} label={'确认密码'}
placeholder='确认密码' placeholder="确认密码"
name='password2' name="password2"
type='password' type="password"
onChange={(value) => handleChange('password2', value)} onChange={(value) => handleChange('password2', value)}
/> />
{showEmailVerification ? ( {showEmailVerification ? (
@@ -145,15 +143,12 @@ const RegisterForm = () => {
<Form.Input <Form.Input
field={'email'} field={'email'}
label={'邮箱'} label={'邮箱'}
placeholder='输入邮箱地址' placeholder="输入邮箱地址"
onChange={(value) => handleChange('email', value)} onChange={(value) => handleChange('email', value)}
name='email' name="email"
type='email' type="email"
suffix={ suffix={
<Button <Button onClick={sendVerificationCode} disabled={loading}>
onClick={sendVerificationCode}
disabled={loading}
>
获取验证码 获取验证码
</Button> </Button>
} }
@@ -161,11 +156,9 @@ const RegisterForm = () => {
<Form.Input <Form.Input
field={'verification_code'} field={'verification_code'}
label={'验证码'} label={'验证码'}
placeholder='输入验证码' placeholder="输入验证码"
onChange={(value) => onChange={(value) => handleChange('verification_code', value)}
handleChange('verification_code', value) name="verification_code"
}
name='verification_code'
/> />
</> </>
) : ( ) : (
@@ -186,12 +179,14 @@ const RegisterForm = () => {
style={{ style={{
display: 'flex', display: 'flex',
justifyContent: 'space-between', justifyContent: 'space-between',
marginTop: 20, marginTop: 20
}} }}
> >
<Text> <Text>
已有账户 已有账户
<Link to='/login'>点击登录</Link> <Link to="/login">
点击登录
</Link>
</Text> </Text>
</div> </div>
</Card> </Card>

View File

@@ -0,0 +1,790 @@
import React, { useEffect, useState } from 'react';
import {
Button,
Divider,
Form,
Grid,
Header,
Message,
Modal,
} from 'semantic-ui-react';
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
import { useTheme } from '../context/Theme';
const SafetySetting = () => {
let [inputs, setInputs] = useState({
PasswordLoginEnabled: '',
PasswordRegisterEnabled: '',
EmailVerificationEnabled: '',
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
SMTPAccount: '',
SMTPFrom: '',
SMTPToken: '',
ServerAddress: '',
WorkerUrl: '',
WorkerValidKey: '',
EpayId: '',
EpayKey: '',
Price: 7.3,
MinTopUp: 1,
TopupGroupRatio: '',
PayAddress: '',
CustomCallbackAddress: '',
Footer: '',
WeChatAuthEnabled: '',
WeChatServerAddress: '',
WeChatServerToken: '',
WeChatAccountQRCodeImageURL: '',
TurnstileCheckEnabled: '',
TurnstileSiteKey: '',
TurnstileSecretKey: '',
RegisterEnabled: '',
EmailDomainRestrictionEnabled: '',
EmailAliasRestrictionEnabled: '',
SMTPSSLEnabled: '',
EmailDomainWhitelist: [],
// telegram login
TelegramOAuthEnabled: '',
TelegramBotToken: '',
TelegramBotName: '',
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]);
const [restrictedDomainInput, setRestrictedDomainInput] = useState('');
const [showPasswordWarningModal, setShowPasswordWarningModal] =
useState(false);
const theme = useTheme();
const isDark = theme === 'dark';
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'TopupGroupRatio') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
});
setInputs({
...newInputs,
EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(','),
});
setOriginInputs(newInputs);
setEmailDomainWhitelist(
newInputs.EmailDomainWhitelist.split(',').map((item) => {
return { key: item, text: item, value: item };
}),
);
} else {
showError(message);
}
};
useEffect(() => {
getOptions().then();
}, []);
useEffect(() => {}, [inputs.EmailDomainWhitelist]);
const updateOption = async (key, value) => {
setLoading(true);
switch (key) {
case 'PasswordLoginEnabled':
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled':
case 'WeChatAuthEnabled':
case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'EmailAliasRestrictionEnabled':
case 'SMTPSSLEnabled':
case 'RegisterEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
break;
}
const res = await API.put('/api/option/', {
key,
value,
});
const { success, message } = res.data;
if (success) {
if (key === 'EmailDomainWhitelist') {
value = value.split(',');
}
if (key === 'Price') {
value = parseFloat(value);
}
setInputs((inputs) => ({
...inputs,
[key]: value,
}));
} else {
showError(message);
}
setLoading(false);
};
const handleInputChange = async (e, { name, value }) => {
if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') {
// block disabling password login
setShowPasswordWarningModal(true);
return;
}
if (
name === 'Notice' ||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' ||
name === 'WorkerUrl' ||
name === 'WorkerValidKey' ||
name === 'EpayId' ||
name === 'EpayKey' ||
name === 'Price' ||
name === 'PayAddress' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
name === 'TurnstileSiteKey' ||
name === 'TurnstileSecretKey' ||
name === 'EmailDomainWhitelist' ||
name === 'TopupGroupRatio' ||
name === 'TelegramBotToken' ||
name === 'TelegramBotName'
) {
setInputs((inputs) => ({ ...inputs, [name]: value }));
} else {
await updateOption(name, value);
}
};
const submitServerAddress = async () => {
let ServerAddress = removeTrailingSlash(inputs.ServerAddress);
await updateOption('ServerAddress', ServerAddress);
};
const submitWorker = async () => {
let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
await updateOption('WorkerUrl', WorkerUrl);
if (inputs.WorkerValidKey !== '') {
await updateOption('WorkerValidKey', inputs.WorkerValidKey);
}
}
const submitPayAddress = async () => {
if (inputs.ServerAddress === '') {
showError('请先填写服务器地址');
return;
}
if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) {
if (!verifyJSON(inputs.TopupGroupRatio)) {
showError('充值分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('TopupGroupRatio', inputs.TopupGroupRatio);
}
let PayAddress = removeTrailingSlash(inputs.PayAddress);
await updateOption('PayAddress', PayAddress);
if (inputs.EpayId !== '') {
await updateOption('EpayId', inputs.EpayId);
}
if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
await updateOption('EpayKey', inputs.EpayKey);
}
await updateOption('Price', '' + inputs.Price);
};
const submitSMTP = async () => {
if (originInputs['SMTPServer'] !== inputs.SMTPServer) {
await updateOption('SMTPServer', inputs.SMTPServer);
}
if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) {
await updateOption('SMTPAccount', inputs.SMTPAccount);
}
if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) {
await updateOption('SMTPFrom', inputs.SMTPFrom);
}
if (
originInputs['SMTPPort'] !== inputs.SMTPPort &&
inputs.SMTPPort !== ''
) {
await updateOption('SMTPPort', inputs.SMTPPort);
}
if (
originInputs['SMTPToken'] !== inputs.SMTPToken &&
inputs.SMTPToken !== ''
) {
await updateOption('SMTPToken', inputs.SMTPToken);
}
};
const submitEmailDomainWhitelist = async () => {
if (
originInputs['EmailDomainWhitelist'] !==
inputs.EmailDomainWhitelist.join(',') &&
inputs.SMTPToken !== ''
) {
await updateOption(
'EmailDomainWhitelist',
inputs.EmailDomainWhitelist.join(','),
);
}
};
const submitWeChat = async () => {
if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) {
await updateOption(
'WeChatServerAddress',
removeTrailingSlash(inputs.WeChatServerAddress),
);
}
if (
originInputs['WeChatAccountQRCodeImageURL'] !==
inputs.WeChatAccountQRCodeImageURL
) {
await updateOption(
'WeChatAccountQRCodeImageURL',
inputs.WeChatAccountQRCodeImageURL,
);
}
if (
originInputs['WeChatServerToken'] !== inputs.WeChatServerToken &&
inputs.WeChatServerToken !== ''
) {
await updateOption('WeChatServerToken', inputs.WeChatServerToken);
}
};
const submitGitHubOAuth = async () => {
if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) {
await updateOption('GitHubClientId', inputs.GitHubClientId);
}
if (
originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret &&
inputs.GitHubClientSecret !== ''
) {
await updateOption('GitHubClientSecret', inputs.GitHubClientSecret);
}
};
const submitTelegramSettings = async () => {
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
await updateOption('TelegramBotToken', inputs.TelegramBotToken);
await updateOption('TelegramBotName', inputs.TelegramBotName);
};
const submitTurnstile = async () => {
if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
}
if (
originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey &&
inputs.TurnstileSecretKey !== ''
) {
await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey);
}
};
const submitNewRestrictedDomain = () => {
const localDomainList = inputs.EmailDomainWhitelist;
if (
restrictedDomainInput !== '' &&
!localDomainList.includes(restrictedDomainInput)
) {
setRestrictedDomainInput('');
setInputs({
...inputs,
EmailDomainWhitelist: [...localDomainList, restrictedDomainInput],
});
setEmailDomainWhitelist([
...EmailDomainWhitelist,
{
key: restrictedDomainInput,
text: restrictedDomainInput,
value: restrictedDomainInput,
},
]);
}
};
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading} inverted={isDark}>
<Header as='h3' inverted={isDark}>
通用设置
</Header>
<Form.Group widths='equal'>
<Form.Input
label='服务器地址'
placeholder='例如https://yourdomain.com'
value={inputs.ServerAddress}
name='ServerAddress'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitServerAddress}>
更新服务器地址
</Form.Button>
<Header as='h3' inverted={isDark}>
代理设置支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>
</Header>
<Form.Group widths='equal'>
<Form.Input
label='Worker地址不填写则不启用代理'
placeholder='例如https://workername.yourdomain.workers.dev'
value={inputs.WorkerUrl}
name='WorkerUrl'
onChange={handleInputChange}
/>
<Form.Input
label='Worker密钥根据你部署的 Worker 填写'
placeholder='例如your_secret_key'
value={inputs.WorkerValidKey}
name='WorkerValidKey'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={submitWorker}>
更新Worker设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址
</Header>
<Form.Group widths='equal'>
<Form.Input
label='支付地址,不填写则不启用在线支付'
placeholder='例如https://yourdomain.com'
value={inputs.PayAddress}
name='PayAddress'
onChange={handleInputChange}
/>
<Form.Input
label='易支付商户ID'
placeholder='例如0001'
value={inputs.EpayId}
name='EpayId'
onChange={handleInputChange}
/>
<Form.Input
label='易支付商户密钥'
placeholder='敏感信息不会发送到前端显示'
value={inputs.EpayKey}
name='EpayKey'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.Input
label='回调地址,不填写则使用上方服务器地址作为回调地址'
placeholder='例如https://yourdomain.com'
value={inputs.CustomCallbackAddress}
name='CustomCallbackAddress'
onChange={handleInputChange}
/>
<Form.Input
label='充值价格x元/美金)'
placeholder='例如7就是7元/美金'
value={inputs.Price}
name='Price'
min={0}
onChange={handleInputChange}
/>
<Form.Input
label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
placeholder='例如2就是最低充值2$'
value={inputs.MinTopUp}
name='MinTopUp'
min={1}
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='充值分组倍率'
name='TopupGroupRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
value={inputs.TopupGroupRatio}
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
/>
</Form.Group>
<Form.Button onClick={submitPayAddress}>更新支付设置</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置登录注册
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.PasswordLoginEnabled === 'true'}
label='允许通过密码进行登录'
name='PasswordLoginEnabled'
onChange={handleInputChange}
/>
{showPasswordWarningModal && (
<Modal
open={showPasswordWarningModal}
onClose={() => setShowPasswordWarningModal(false)}
size={'tiny'}
style={{ maxWidth: '450px' }}
>
<Modal.Header>警告</Modal.Header>
<Modal.Content>
<p>
取消密码登录将导致所有未绑定其他登录方式的用户包括管理员无法通过密码登录确认取消
</p>
</Modal.Content>
<Modal.Actions>
<Button onClick={() => setShowPasswordWarningModal(false)}>
取消
</Button>
<Button
color='yellow'
onClick={async () => {
setShowPasswordWarningModal(false);
await updateOption('PasswordLoginEnabled', 'false');
}}
>
确定
</Button>
</Modal.Actions>
</Modal>
)}
<Form.Checkbox
checked={inputs.PasswordRegisterEnabled === 'true'}
label='允许通过密码进行注册'
name='PasswordRegisterEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.EmailVerificationEnabled === 'true'}
label='通过密码注册时需要进行邮箱验证'
name='EmailVerificationEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.GitHubOAuthEnabled === 'true'}
label='允许通过 GitHub 账户登录 & 注册'
name='GitHubOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.WeChatAuthEnabled === 'true'}
label='允许通过微信登录 & 注册'
name='WeChatAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.TelegramOAuthEnabled === 'true'}
label='允许通过 Telegram 进行登录'
name='TelegramOAuthEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.RegisterEnabled === 'true'}
label='允许新用户注册(此项为否时,新用户将无法以任何方式进行注册)'
name='RegisterEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.TurnstileCheckEnabled === 'true'}
label='启用 Turnstile 用户校验'
name='TurnstileCheckEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3' inverted={isDark}>
配置邮箱域名白名单
<Header.Subheader>
用以防止恶意用户利用临时邮箱批量注册
</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱域名白名单'
name='EmailDomainRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱别名限制例如ab.cd@gmail.com'
name='EmailAliasRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailAliasRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={2}>
<Form.Dropdown
label='允许的邮箱域名'
placeholder='允许的邮箱域名'
name='EmailDomainWhitelist'
required
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.EmailDomainWhitelist}
autoComplete='new-password'
options={EmailDomainWhitelist}
/>
<Form.Input
label='添加新的允许的邮箱域名'
action={
<Button
type='button'
onClick={() => {
submitNewRestrictedDomain();
}}
>
填入
</Button>
}
onKeyDown={(e) => {
if (e.key === 'Enter') {
submitNewRestrictedDomain();
}
}}
autoComplete='new-password'
placeholder='输入新的允许的邮箱域名'
value={restrictedDomainInput}
onChange={(e, { value }) => {
setRestrictedDomainInput(value);
}}
/>
</Form.Group>
<Form.Button onClick={submitEmailDomainWhitelist}>
保存邮箱域名白名单设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Input
label='SMTP 服务器地址'
name='SMTPServer'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.SMTPServer}
placeholder='例如smtp.qq.com'
/>
<Form.Input
label='SMTP 端口'
name='SMTPPort'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.SMTPPort}
placeholder='默认: 587'
/>
<Form.Input
label='SMTP 账户'
name='SMTPAccount'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.SMTPAccount}
placeholder='通常是邮箱地址'
/>
</Form.Group>
<Form.Group widths={3}>
<Form.Input
label='SMTP 发送者邮箱'
name='SMTPFrom'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.SMTPFrom}
placeholder='通常和邮箱地址保持一致'
/>
<Form.Input
label='SMTP 访问凭证'
name='SMTPToken'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
checked={inputs.RegisterEnabled === 'true'}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Group widths={3}>
<Form.Checkbox
label='启用SMTP SSL465端口强制开启'
name='SMTPSSLEnabled'
onChange={handleInputChange}
checked={inputs.SMTPSSLEnabled === 'true'}
/>
</Form.Group>
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 GitHub OAuth App
<Header.Subheader>
用以支持通过 GitHub 进行登录注册
<a
href='https://github.com/settings/developers'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
管理你的 GitHub OAuth App
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/github`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='GitHub Client ID'
name='GitHubClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.GitHubClientId}
placeholder='输入你注册的 GitHub OAuth APP 的 ID'
/>
<Form.Input
label='GitHub Client Secret'
name='GitHubClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.GitHubClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitGitHubOAuth}>
保存 GitHub OAuth 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 WeChat Server
<Header.Subheader>
用以支持通过微信进行登录注册
<a
href='https://github.com/songquanpeng/wechat-server'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
了解 WeChat Server
</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Input
label='WeChat Server 服务器地址'
name='WeChatServerAddress'
placeholder='例如https://yourdomain.com'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.WeChatServerAddress}
/>
<Form.Input
label='WeChat Server 访问凭证'
name='WeChatServerToken'
type='password'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.WeChatServerToken}
placeholder='敏感信息不会发送到前端显示'
/>
<Form.Input
label='微信公众号二维码图片链接'
name='WeChatAccountQRCodeImageURL'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.WeChatAccountQRCodeImageURL}
placeholder='输入一个图片链接'
/>
</Form.Group>
<Form.Button onClick={submitWeChat}>
保存 WeChat Server 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 Telegram 登录
</Header>
<Form.Group inline>
<Form.Input
label='Telegram Bot Token'
name='TelegramBotToken'
onChange={handleInputChange}
value={inputs.TelegramBotToken}
placeholder='输入你的 Telegram Bot Token'
/>
<Form.Input
label='Telegram Bot 名称'
name='TelegramBotName'
onChange={handleInputChange}
value={inputs.TelegramBotName}
placeholder='输入你的 Telegram Bot 名称'
/>
</Form.Group>
<Form.Button onClick={submitTelegramSettings}>
保存 Telegram 登录设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 Turnstile
<Header.Subheader>
用以支持用户校验
<a
href='https://dash.cloudflare.com/'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
管理你的 Turnstile Sites推荐选择 Invisible Widget Type
</Header.Subheader>
</Header>
<Form.Group widths={3}>
<Form.Input
label='Turnstile Site Key'
name='TurnstileSiteKey'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.TurnstileSiteKey}
placeholder='输入你注册的 Turnstile Site Key'
/>
<Form.Input
label='Turnstile Secret Key'
name='TurnstileSecretKey'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.TurnstileSecretKey}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitTurnstile}>
保存 Turnstile 设置
</Form.Button>
</Form>
</Grid.Column>
</Grid>
);
};
export default SystemSetting;

View File

@@ -14,11 +14,10 @@ import {
import '../index.css'; import '../index.css';
import { import {
IconCalendarClock, IconCalendarClock, IconChecklistStroked,
IconChecklistStroked, IconComment, IconCommentStroked,
IconComment,
IconCreditCard, IconCreditCard,
IconGift, IconGift, IconHelpCircle,
IconHistogram, IconHistogram,
IconHome, IconHome,
IconImage, IconImage,
@@ -26,10 +25,12 @@ import {
IconLayers, IconLayers,
IconPriceTag, IconPriceTag,
IconSetting, IconSetting,
IconUser, IconUser
} from '@douyinfe/semi-icons'; } from '@douyinfe/semi-icons';
import { Layout, Nav } from '@douyinfe/semi-ui'; import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
import { setStatusData } from '../helpers/data.js'; import { setStatusData } from '../helpers/data.js';
import { stringToColor } from '../helpers/render.js';
import { useSetTheme, useTheme } from '../context/Theme/index.js';
// HeaderBar Buttons // HeaderBar Buttons
@@ -39,11 +40,11 @@ const SiderBar = () => {
const defaultIsCollapsed = const defaultIsCollapsed =
isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'; isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true';
let navigate = useNavigate();
const [selectedKeys, setSelectedKeys] = useState(['home']); const [selectedKeys, setSelectedKeys] = useState(['home']);
const systemName = getSystemName();
const logo = getLogo();
const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed); const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed);
const [chatItems, setChatItems] = useState([]);
const theme = useTheme();
const setTheme = useSetTheme();
const routerMap = { const routerMap = {
home: '/', home: '/',
@@ -60,15 +61,22 @@ const SiderBar = () => {
detail: '/detail', detail: '/detail',
pricing: '/pricing', pricing: '/pricing',
task: '/task', task: '/task',
playground: '/playground',
}; };
const headerButtons = useMemo( const headerButtons = useMemo(
() => [ () => [
{ {
text: '首页', text: 'Playground',
itemKey: 'home', itemKey: 'playground',
to: '/', to: '/playground',
icon: <IconHome />, icon: <IconCommentStroked />,
},
{
text: '模型价格',
itemKey: 'pricing',
to: '/pricing',
icon: <IconPriceTag />,
}, },
{ {
text: '渠道', text: '渠道',
@@ -80,11 +88,12 @@ const SiderBar = () => {
{ {
text: '聊天', text: '聊天',
itemKey: 'chat', itemKey: 'chat',
to: '/chat', // to: '/chat',
items: chatItems,
icon: <IconComment />, icon: <IconComment />,
className: localStorage.getItem('chat_link') // className: localStorage.getItem('chat_link')
? 'semi-navigation-item-normal' // ? 'semi-navigation-item-normal'
: 'tableHiddle', // : 'tableHiddle',
}, },
{ {
text: '令牌', text: '令牌',
@@ -105,12 +114,6 @@ const SiderBar = () => {
to: '/topup', to: '/topup',
icon: <IconCreditCard />, icon: <IconCreditCard />,
}, },
{
text: '模型价格',
itemKey: 'pricing',
to: '/pricing',
icon: <IconPriceTag />,
},
{ {
text: '用户管理', text: '用户管理',
itemKey: 'user', itemKey: 'user',
@@ -150,9 +153,9 @@ const SiderBar = () => {
to: '/task', to: '/task',
icon: <IconChecklistStroked />, icon: <IconChecklistStroked />,
className: className:
localStorage.getItem('enable_task') === 'true' localStorage.getItem('enable_task') === 'true'
? 'semi-navigation-item-normal' ? 'semi-navigation-item-normal'
: 'tableHiddle', : 'tableHiddle',
}, },
{ {
text: '设置', text: '设置',
@@ -171,7 +174,7 @@ const SiderBar = () => {
localStorage.getItem('enable_data_export'), localStorage.getItem('enable_data_export'),
localStorage.getItem('enable_drawing'), localStorage.getItem('enable_drawing'),
localStorage.getItem('enable_task'), localStorage.getItem('enable_task'),
localStorage.getItem('chat_link'), localStorage.getItem('chat_link'), chatItems,
isAdmin(), isAdmin(),
], ],
); );
@@ -202,52 +205,101 @@ const SiderBar = () => {
localKey = 'home'; localKey = 'home';
} }
setSelectedKeys([localKey]); setSelectedKeys([localKey]);
let chatLink = localStorage.getItem('chat_link');
if (!chatLink) {
let chats = localStorage.getItem('chats');
if (chats) {
// console.log(chats);
try {
chats = JSON.parse(chats);
if (Array.isArray(chats)) {
let chatItems = [];
for (let i = 0; i < chats.length; i++) {
let chat = {};
for (let key in chats[i]) {
chat.text = key;
chat.itemKey = 'chat' + i;
chat.to = '/chat/' + i;
}
// setRouterMap({ ...routerMap, chat: '/chat/' + i })
chatItems.push(chat);
}
setChatItems(chatItems);
}
} catch (e) {
console.error(e);
showError('聊天数据解析失败')
}
}
}
}, []); }, []);
return ( return (
<> <>
<Layout> <Nav
<div style={{ height: '100%' }}> style={{ maxWidth: 220, height: '100%' }}
<Nav defaultIsCollapsed={
// bodyStyle={{ maxWidth: 200 }} isMobile() ||
style={{ maxWidth: 200 }} localStorage.getItem('default_collapse_sidebar') === 'true'
defaultIsCollapsed={ }
isMobile() || isCollapsed={isCollapsed}
localStorage.getItem('default_collapse_sidebar') === 'true' onCollapseChange={(collapsed) => {
setIsCollapsed(collapsed);
}}
selectedKeys={selectedKeys}
renderWrapper={({ itemElement, isSubNav, isInSubNav, props }) => {
let chatLink = localStorage.getItem('chat_link');
if (!chatLink) {
let chats = localStorage.getItem('chats');
if (chats) {
chats = JSON.parse(chats);
if (Array.isArray(chats) && chats.length > 0) {
for (let i = 0; i < chats.length; i++) {
routerMap['chat' + i] = '/chat/' + i;
}
if (chats.length > 1) {
// delete /chat
if (routerMap['chat']) {
delete routerMap['chat'];
}
} else {
// rename /chat to /chat/0
routerMap['chat'] = '/chat/0';
}
}
}
} }
isCollapsed={isCollapsed} return (
onCollapseChange={(collapsed) => { <Link
setIsCollapsed(collapsed); style={{ textDecoration: 'none' }}
}} to={routerMap[props.itemKey]}
selectedKeys={selectedKeys} >
renderWrapper={({ itemElement, isSubNav, isInSubNav, props }) => { {itemElement}
return ( </Link>
<Link );
style={{ textDecoration: 'none' }} }}
to={routerMap[props.itemKey]} items={headerButtons}
> onSelect={(key) => {
{itemElement} setSelectedKeys([key.itemKey]);
</Link> }}
); footer={
}} <>
items={headerButtons} {isMobile() && (
onSelect={(key) => { <Switch
setSelectedKeys([key.itemKey]); checkedText='🌞'
}} size={'small'}
header={{ checked={theme === 'dark'}
logo: ( uncheckedText='🌙'
<img src={logo} alt='logo' style={{ marginRight: '0.75em' }} /> onChange={(checked) => {
), setTheme(checked);
text: systemName, }}
}} />
// footer={{ )}
// text: '© 2021 NekoAPI', </>
// }} }
> >
<Nav.Footer collapseButton={true}></Nav.Footer> <Nav.Footer collapseButton={true}></Nav.Footer>
</Nav> </Nav>
</div>
</Layout>
</> </>
); );
}; };

View File

@@ -20,10 +20,6 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '', GitHubOAuthEnabled: '',
GitHubClientId: '', GitHubClientId: '',
GitHubClientSecret: '', GitHubClientSecret: '',
LinuxDoOAuthEnabled: '',
LinuxDoClientId: '',
LinuxDoClientSecret: '',
LinuxDoMinLevel: 0,
Notice: '', Notice: '',
SMTPServer: '', SMTPServer: '',
SMTPPort: '', SMTPPort: '',
@@ -31,14 +27,15 @@ const SystemSetting = () => {
SMTPFrom: '', SMTPFrom: '',
SMTPToken: '', SMTPToken: '',
ServerAddress: '', ServerAddress: '',
OutProxyUrl: '', WorkerUrl: '',
StripeApiSecret: '', WorkerValidKey: '',
StripeWebhookSecret: '', EpayId: '',
StripePriceId: '', EpayKey: '',
PaymentEnabled: false, Price: 7.3,
StripeUnitPrice: 8.0, MinTopUp: 1,
MinTopUp: 5,
TopupGroupRatio: '', TopupGroupRatio: '',
PayAddress: '',
CustomCallbackAddress: '',
Footer: '', Footer: '',
WeChatAuthEnabled: '', WeChatAuthEnabled: '',
WeChatServerAddress: '', WeChatServerAddress: '',
@@ -48,7 +45,6 @@ const SystemSetting = () => {
TurnstileSiteKey: '', TurnstileSiteKey: '',
TurnstileSecretKey: '', TurnstileSecretKey: '',
RegisterEnabled: '', RegisterEnabled: '',
UserSelfDeletionEnabled: false,
EmailDomainRestrictionEnabled: '', EmailDomainRestrictionEnabled: '',
EmailAliasRestrictionEnabled: '', EmailAliasRestrictionEnabled: '',
SMTPSSLEnabled: '', SMTPSSLEnabled: '',
@@ -107,7 +103,6 @@ const SystemSetting = () => {
case 'PasswordRegisterEnabled': case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled': case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled': case 'GitHubOAuthEnabled':
case 'LinuxDoOAuthEnabled':
case 'WeChatAuthEnabled': case 'WeChatAuthEnabled':
case 'TelegramOAuthEnabled': case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled': case 'TurnstileCheckEnabled':
@@ -115,8 +110,6 @@ const SystemSetting = () => {
case 'EmailAliasRestrictionEnabled': case 'EmailAliasRestrictionEnabled':
case 'SMTPSSLEnabled': case 'SMTPSSLEnabled':
case 'RegisterEnabled': case 'RegisterEnabled':
case 'UserSelfDeletionEnabled':
case 'PaymentEnabled':
value = inputs[key] === 'true' ? 'false' : 'true'; value = inputs[key] === 'true' ? 'false' : 'true';
break; break;
default: default:
@@ -131,6 +124,9 @@ const SystemSetting = () => {
if (key === 'EmailDomainWhitelist') { if (key === 'EmailDomainWhitelist') {
value = value.split(','); value = value.split(',');
} }
if (key === 'Price') {
value = parseFloat(value);
}
setInputs((inputs) => ({ setInputs((inputs) => ({
...inputs, ...inputs,
[key]: value, [key]: value,
@@ -151,17 +147,14 @@ const SystemSetting = () => {
name === 'Notice' || name === 'Notice' ||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') || (name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' || name === 'ServerAddress' ||
name === 'OutProxyUrl' || name === 'WorkerUrl' ||
name === 'StripeApiSecret' || name === 'WorkerValidKey' ||
name === 'StripeWebhookSecret' || name === 'EpayId' ||
name === 'StripePriceId' || name === 'EpayKey' ||
name === 'StripeUnitPrice' || name === 'Price' ||
name === 'MinTopUp' || name === 'PayAddress' ||
name === 'GitHubClientId' || name === 'GitHubClientId' ||
name === 'GitHubClientSecret' || name === 'GitHubClientSecret' ||
name === 'LinuxDoClientId' ||
name === 'LinuxDoClientSecret' ||
name === 'LinuxDoMinLevel' ||
name === 'WeChatServerAddress' || name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' || name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' || name === 'WeChatAccountQRCodeImageURL' ||
@@ -183,12 +176,15 @@ const SystemSetting = () => {
await updateOption('ServerAddress', ServerAddress); await updateOption('ServerAddress', ServerAddress);
}; };
const submitOutProxyUrl = async () => { const submitWorker = async () => {
let OutProxyUrl = removeTrailingSlash(inputs.OutProxyUrl); let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
await updateOption('OutProxyUrl', OutProxyUrl); await updateOption('WorkerUrl', WorkerUrl);
}; if (inputs.WorkerValidKey !== '') {
await updateOption('WorkerValidKey', inputs.WorkerValidKey);
}
}
const submitPaymentConfig = async () => { const submitPayAddress = async () => {
if (inputs.ServerAddress === '') { if (inputs.ServerAddress === '') {
showError('请先填写服务器地址'); showError('请先填写服务器地址');
return; return;
@@ -200,31 +196,15 @@ const SystemSetting = () => {
} }
await updateOption('TopupGroupRatio', inputs.TopupGroupRatio); await updateOption('TopupGroupRatio', inputs.TopupGroupRatio);
} }
let stripeApiSecret = removeTrailingSlash(inputs.StripeApiSecret); let PayAddress = removeTrailingSlash(inputs.PayAddress);
if (stripeApiSecret && !stripeApiSecret.startsWith('sk_')) { await updateOption('PayAddress', PayAddress);
showError('输入了无效的Stripe API密钥'); if (inputs.EpayId !== '') {
return; await updateOption('EpayId', inputs.EpayId);
} }
stripeApiSecret && (await updateOption('StripeApiSecret', stripeApiSecret)); if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
await updateOption('EpayKey', inputs.EpayKey);
let stripeWebhookSecret = removeTrailingSlash(inputs.StripeWebhookSecret);
if (stripeWebhookSecret && !stripeWebhookSecret.startsWith('whsec_')) {
showError('输入了无效的Stripe Webhook签名密钥');
return;
} }
stripeWebhookSecret && await updateOption('Price', '' + inputs.Price);
(await updateOption('StripeWebhookSecret', stripeWebhookSecret));
let stripePriceId = removeTrailingSlash(inputs.StripePriceId);
if (stripePriceId && !stripePriceId.startsWith('price_')) {
showError('输入了无效的Stripe 物品价格ID');
return;
}
await updateOption('StripePriceId', stripePriceId);
await updateOption('PaymentEnable', inputs.PaymentEnabled);
await updateOption('StripeUnitPrice', inputs.StripeUnitPrice);
await updateOption('MinTopUp', inputs.MinTopUp);
}; };
const submitSMTP = async () => { const submitSMTP = async () => {
@@ -300,21 +280,6 @@ const SystemSetting = () => {
} }
}; };
const submitLinuxDoOAuth = async () => {
if (originInputs['LinuxDoClientId'] !== inputs.LinuxDoClientId) {
await updateOption('LinuxDoClientId', inputs.LinuxDoClientId);
}
if (
originInputs['LinuxDoClientSecret'] !== inputs.LinuxDoClientSecret &&
inputs.LinuxDoClientSecret !== ''
) {
await updateOption('LinuxDoClientSecret', inputs.LinuxDoClientSecret);
}
if (originInputs['LinuxDoMinLevel'] !== inputs.LinuxDoMinLevel) {
await updateOption('LinuxDoMinLevel', inputs.LinuxDoMinLevel);
}
};
const submitTelegramSettings = async () => { const submitTelegramSettings = async () => {
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled); // await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
await updateOption('TelegramBotToken', inputs.TelegramBotToken); await updateOption('TelegramBotToken', inputs.TelegramBotToken);
@@ -374,88 +339,76 @@ const SystemSetting = () => {
<Form.Button onClick={submitServerAddress}> <Form.Button onClick={submitServerAddress}>
更新服务器地址 更新服务器地址
</Form.Button> </Form.Button>
<Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
代理设置 代理设置支持 <a href='https://github.com/Calcium-Ion/new-api-worker' target='_blank' rel='noreferrer'>new-api-worker</a>
</Header> </Header>
<Form.Group widths='equal'> <Form.Group widths='equal'>
<Form.Input <Form.Input
label='出口代理地址' label='Worker地址不填写则不启用代理'
placeholder='例如http://1.2.3.4:8888' placeholder='例如https://workername.yourdomain.workers.dev'
value={inputs.OutProxyUrl} value={inputs.WorkerUrl}
name='OutProxyUrl' name='WorkerUrl'
onChange={handleInputChange}
/>
<Form.Input
label='Worker密钥根据你部署的 Worker 填写'
placeholder='例如your_secret_key'
value={inputs.WorkerValidKey}
name='WorkerValidKey'
onChange={handleInputChange} onChange={handleInputChange}
/> />
</Form.Group> </Form.Group>
<Form.Button onClick={submitOutProxyUrl}>更新代理设置</Form.Button> <Form.Button onClick={submitWorker}>
更新Worker设置
</Form.Button>
<Divider /> <Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
支付设置当前仅支持Stripe Checkout 支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址
<Header.Subheader>
密钥Webhook 等设置请
<a
href='https://dashboard.stripe.com/developers'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
进行设置最好先在
<a
href='https://dashboard.stripe.com/test/developers'
target='_blank'
rel='noreferrer'
>
测试环境
</a>
进行测试
</Header.Subheader>
</Header> </Header>
<Message>
Webhook
<code>{`${inputs.ServerAddress}/api/stripe/webhook`}</code>
需要包含事件<code>checkout.session.completed</code> {' '}
<code>checkout.session.expired</code>
</Message>
<Form.Group widths='equal'> <Form.Group widths='equal'>
<Form.Input <Form.Input
label='API密钥' label='支付地址,不填写则不启用在线支付'
placeholder='sk_xxx的Stripe密钥敏感信息不显示' placeholder='例如https://yourdomain.com'
value={inputs.StripeApiSecret} value={inputs.PayAddress}
name='StripeApiSecret' name='PayAddress'
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Input <Form.Input
label='Webhook签名密钥' label='易支付商户ID'
placeholder='whsec_xxx的Webhook签名密钥敏感信息不显示' placeholder='例如0001'
value={inputs.StripeWebhookSecret} value={inputs.EpayId}
name='StripeWebhookSecret' name='EpayId'
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Input <Form.Input
label='商品价格ID' label='易支付商户密钥'
placeholder='price_xxx的商品价格ID新建产品后可获得' placeholder='敏感信息不会发送到前端显示'
value={inputs.StripePriceId} value={inputs.EpayKey}
name='StripePriceId' name='EpayKey'
onChange={handleInputChange} onChange={handleInputChange}
/> />
</Form.Group> </Form.Group>
<Form.Group widths='equal'> <Form.Group widths='equal'>
<Form.Input <Form.Input
label='商品单价(元)' label='回调地址,不填写则使用上方服务器地址作为回调地址'
placeholder='商品的人民币价格' placeholder='例如https://yourdomain.com'
value={inputs.StripeUnitPrice} value={inputs.CustomCallbackAddress}
name='StripeUnitPrice' name='CustomCallbackAddress'
type={'number'} onChange={handleInputChange}
/>
<Form.Input
label='充值价格x元/美金)'
placeholder='例如7就是7元/美金'
value={inputs.Price}
name='Price'
min={0} min={0}
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Input <Form.Input
label='最低充值数量' label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
placeholder='例如2就是最低充值2件商品' placeholder='例如2就是最低充值2$'
value={inputs.MinTopUp} value={inputs.MinTopUp}
name='MinTopUp' name='MinTopUp'
type={'number'}
min={1} min={1}
onChange={handleInputChange} onChange={handleInputChange}
/> />
@@ -471,17 +424,7 @@ const SystemSetting = () => {
placeholder='为一个 JSON 文本,键为组名称,值为倍率' placeholder='为一个 JSON 文本,键为组名称,值为倍率'
/> />
</Form.Group> </Form.Group>
<Form.Group inline> <Form.Button onClick={submitPayAddress}>更新支付设置</Form.Button>
<Form.Button onClick={submitPaymentConfig}>
更新支付设置
</Form.Button>
<Form.Checkbox
checked={inputs.PaymentEnabled === 'true'}
label='开启在线支付'
name='PaymentEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider /> <Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
配置登录注册 配置登录注册
@@ -540,12 +483,6 @@ const SystemSetting = () => {
name='GitHubOAuthEnabled' name='GitHubOAuthEnabled'
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Checkbox
checked={inputs.LinuxDoOAuthEnabled === 'true'}
label='允许通过 LINUX DO 账户登录 & 注册'
name='LinuxDoOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox <Form.Checkbox
checked={inputs.WeChatAuthEnabled === 'true'} checked={inputs.WeChatAuthEnabled === 'true'}
label='允许通过微信登录 & 注册' label='允许通过微信登录 & 注册'
@@ -572,12 +509,6 @@ const SystemSetting = () => {
name='TurnstileCheckEnabled' name='TurnstileCheckEnabled'
onChange={handleInputChange} onChange={handleInputChange}
/> />
<Form.Checkbox
checked={inputs.UserSelfDeletionEnabled === 'true'}
label='允许用户自行删除账户'
name='UserSelfDeletionEnabled'
onChange={handleInputChange}
/>
</Form.Group> </Form.Group>
<Divider /> <Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
@@ -746,58 +677,6 @@ const SystemSetting = () => {
保存 GitHub OAuth 设置 保存 GitHub OAuth 设置
</Form.Button> </Form.Button>
<Divider /> <Divider />
<Header as='h3'>
配置 LINUX DO Oauth
<Header.Subheader>
用以支持通过 LINUX DO 进行登录注册
<a
href='https://connect.linux.do'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
管理你的 LINUX DO OAuth
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/linuxdo`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='LINUX DO Client ID'
name='LinuxDoClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.LinuxDoClientId}
placeholder='输入你注册的 LINUX DO OAuth 的 ID'
/>
<Form.Input
label='LINUX DO Client Secret'
name='LinuxDoClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.LinuxDoClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
<Form.Input
label='限制最低信任等级'
name='LinuxDoMinLevel'
onChange={handleInputChange}
type='number'
min={0}
max={4}
value={inputs.LinuxDoMinLevel}
placeholder='输入允许使用的最低 LINUX DO 信任等级'
/>
</Form.Group>
<Form.Button onClick={submitLinuxDoOAuth}>
保存 LINUX DO OAuth 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}> <Header as='h3' inverted={isDark}>
配置 WeChat Server 配置 WeChat Server
<Header.Subheader> <Header.Subheader>

Some files were not shown because too many files have changed in this diff Show More