mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
125 Commits
v0.2.1.0-a
...
v0.2.6-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
811019bf5c | ||
|
|
0ed6600437 | ||
|
|
1cff3c100a | ||
|
|
d7a343e2f6 | ||
|
|
637801fba5 | ||
|
|
2bf404507f | ||
|
|
675de89c69 | ||
|
|
16b9aacb06 | ||
|
|
cad380eb16 | ||
|
|
234e39ddeb | ||
|
|
7fb6420e66 | ||
|
|
5425b5bfc3 | ||
|
|
21f32605c8 | ||
|
|
1fe7f14d57 | ||
|
|
a7bafec1bf | ||
|
|
1c6fd87909 | ||
|
|
d1c8947851 | ||
|
|
7d2d525051 | ||
|
|
be4809b95a | ||
|
|
e2edd5e7e5 | ||
|
|
a14fa1adb1 | ||
|
|
ed951b3974 | ||
|
|
2cb10b003a | ||
|
|
86b17fcce8 | ||
|
|
08b5336431 | ||
|
|
20aaf30785 | ||
|
|
bfcaccc2e3 | ||
|
|
3f448ba4fc | ||
|
|
408c2bdd9b | ||
|
|
b1b38a6bd4 | ||
|
|
608ec28761 | ||
|
|
a3ccc92f55 | ||
|
|
77e7d11151 | ||
|
|
783e8fd74a | ||
|
|
2841669246 | ||
|
|
89ebd85503 | ||
|
|
1a39ef74ce | ||
|
|
53e8790024 | ||
|
|
9294127686 | ||
|
|
6b97842f78 | ||
|
|
bdc65bdba2 | ||
|
|
76dc7af8d1 | ||
|
|
892b7d1ad4 | ||
|
|
6b71db7ce2 | ||
|
|
b8fb351fd8 | ||
|
|
79cf70683f | ||
|
|
e6765ef32d | ||
|
|
4ef98ba7eb | ||
|
|
65b85377c6 | ||
|
|
c6e85d5b57 | ||
|
|
1162683b4d | ||
|
|
818bd824da | ||
|
|
c74e43b8fd | ||
|
|
6e54f01435 | ||
|
|
505916b755 | ||
|
|
a4defe6ada | ||
|
|
9dfd405ba9 | ||
|
|
6c5b94ceb0 | ||
|
|
ac2984315a | ||
|
|
848358d876 | ||
|
|
e9abe5b705 | ||
|
|
d7e117acf5 | ||
|
|
1456992aae | ||
|
|
3b6ea51033 | ||
|
|
21250a46a6 | ||
|
|
b31fadd74f | ||
|
|
300947f400 | ||
|
|
bf94893f6a | ||
|
|
97af77b26c | ||
|
|
4ef2422b97 | ||
|
|
f188147680 | ||
|
|
03bd9b0cc4 | ||
|
|
149902bd8a | ||
|
|
08e10df887 | ||
|
|
0a49715c3d | ||
|
|
89efed48fc | ||
|
|
97e0aae0a7 | ||
|
|
320da09f36 | ||
|
|
2d849e0dd6 | ||
|
|
60d7ed3fb5 | ||
|
|
c5f6d0e063 | ||
|
|
a7cfce24d0 | ||
|
|
34bf8f8945 | ||
|
|
2d1d1b4631 | ||
|
|
5961de03e7 | ||
|
|
fbdb17022c | ||
|
|
497cc32634 | ||
|
|
462c328d4b | ||
|
|
cd3ed22045 | ||
|
|
2329d387ca | ||
|
|
257cfc2390 | ||
|
|
7d18a8e2a9 | ||
|
|
fed1a1d6a3 | ||
|
|
fc9f8c8e8a | ||
|
|
f3f36dafbd | ||
|
|
aaf3a1f07b | ||
|
|
c040fa229d | ||
|
|
1cd1e54be4 | ||
|
|
3db64afc7f | ||
|
|
bc9cfa5da0 | ||
|
|
660b9b3c99 | ||
|
|
80af3718d0 | ||
|
|
77ea6bec46 | ||
|
|
c0ab8ae953 | ||
|
|
923c2dee32 | ||
|
|
ea17a46d8e | ||
|
|
bfe9e5d25a | ||
|
|
831ff47254 | ||
|
|
11eaba6b5d | ||
|
|
c2e4ec25c8 | ||
|
|
8537f10412 | ||
|
|
71d60eeef7 | ||
|
|
247ae0988f | ||
|
|
0907fa6994 | ||
|
|
9855343aa8 | ||
|
|
bd50fde268 | ||
|
|
4267de5642 | ||
|
|
8b55116563 | ||
|
|
f35e63e3f3 | ||
|
|
17c409de23 | ||
|
|
e4753e7411 | ||
|
|
9adefa80b9 | ||
|
|
4ce2381182 | ||
|
|
62afc21ea5 | ||
|
|
7ddb7c586d |
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
- name: 交流社区
|
||||
url: https://linux.do
|
||||
about: 项目交流社区
|
||||
|
||||
5
.github/workflows/docker-image-amd64.yml
vendored
5
.github/workflows/docker-image-amd64.yml
vendored
@@ -1,9 +1,6 @@
|
||||
name: Publish Docker image (amd64)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
@@ -42,7 +39,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
2
.github/workflows/docker-image-arm64.yml
vendored
2
.github/workflows/docker-image-arm64.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build
|
||||
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
@@ -60,8 +60,8 @@
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
## 渠道重试
|
||||
渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试。
|
||||
如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
|
||||
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
|
||||
### 缓存设置方法
|
||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
|
||||
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
|
||||
// Pay Settings
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
var StripeApiSecret = ""
|
||||
var StripeWebhookSecret = ""
|
||||
var StripePriceId = ""
|
||||
var PaymentEnabled = false
|
||||
var StripeUnitPrice = 8.0
|
||||
var MinTopUp = 5
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
@@ -50,12 +50,15 @@ var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var LinuxDoOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
var UserSelfDeletionEnabled = false
|
||||
|
||||
var EmailDomainRestrictionEnabled = false
|
||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||
var EmailDomainWhitelist = []string{
|
||||
"gmail.com",
|
||||
"163.com",
|
||||
@@ -83,6 +86,10 @@ var SMTPToken = ""
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDoClientId = ""
|
||||
var LinuxDoClientSecret = ""
|
||||
var LinuxDoMinLevel = 0
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
@@ -111,7 +118,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
@@ -184,6 +191,12 @@ const (
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
@@ -214,6 +227,8 @@ const (
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
@@ -249,4 +264,7 @@ var ChannelBaseURLs = []string{
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
}
|
||||
|
||||
@@ -16,7 +16,22 @@ func SafeGoroutine(f func()) {
|
||||
}()
|
||||
}
|
||||
|
||||
func SafeSend(ch chan bool, value bool) (closed bool) {
|
||||
func SafeSendBool(ch chan bool, value bool) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
closed = true
|
||||
}
|
||||
}()
|
||||
|
||||
// This will panic if the channel is closed.
|
||||
ch <- value
|
||||
|
||||
// If the code reaches here, then the channel was not closed.
|
||||
return false
|
||||
}
|
||||
|
||||
func SafeSendString(ch chan string, value string) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
|
||||
84
common/hash.go
Normal file
84
common/hash.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Sha256Raw(data string) []byte {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data string) string {
|
||||
return hex.EncodeToString(Sha1Raw([]byte(data)))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
|
||||
func RandomBytes(length int) []byte {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func RandomString(length int) string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomHex(length int) string {
|
||||
const chars = "abcdef0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomNumber(length int) string {
|
||||
const chars = "0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomUUID() string {
|
||||
all := RandomHex(32)
|
||||
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
|
||||
}
|
||||
@@ -98,3 +98,11 @@ func LogQuota(quota int) string {
|
||||
return fmt.Sprintf("%d 点额度", quota)
|
||||
}
|
||||
}
|
||||
|
||||
func LogQuotaF(quota float64) string {
|
||||
if DisplayInCurrencyEnabled {
|
||||
return fmt.Sprintf("$%.6f 额度", quota/QuotaPerUnit)
|
||||
} else {
|
||||
return fmt.Sprintf("%d 点额度", int64(quota))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
// TODO: when a new api is enabled, check the pricing here
|
||||
// 1 === $0.002 / 1K tokens
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
|
||||
var DefaultModelRatio = map[string]float64{
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
@@ -21,12 +22,14 @@ var DefaultModelRatio = map[string]float64{
|
||||
"gpt-4-32k": 30,
|
||||
"gpt-4-32k-0314": 30,
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
@@ -71,11 +74,14 @@ var DefaultModelRatio = map[string]float64{
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-1.0-pro-vision-001": 1,
|
||||
"gemini-1.0-pro-001": 1,
|
||||
"gemini-1.5-pro": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-1.0-pro-vision-001": 1,
|
||||
"gemini-1.0-pro-001": 1,
|
||||
"gemini-1.5-pro-latest": 1,
|
||||
"gemini-1.0-pro-latest": 1,
|
||||
"gemini-1.0-pro-vision-latest": 1,
|
||||
"gemini-ultra": 1,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
@@ -100,6 +106,14 @@ var DefaultModelRatio = map[string]float64{
|
||||
"yi-34b-chat-0205": 0.018,
|
||||
"yi-34b-chat-200k": 0.0864,
|
||||
"yi-vl-plus": 0.0432,
|
||||
"command": 0.5,
|
||||
"command-nightly": 0.5,
|
||||
"command-light": 0.5,
|
||||
"command-light-nightly": 0.5,
|
||||
"command-r": 0.25,
|
||||
"command-r-plus ": 1.5,
|
||||
"deepseek-chat": 0.07,
|
||||
"deepseek-coder": 0.07,
|
||||
}
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
@@ -190,18 +204,20 @@ func GetModelRatio(name string) float64 {
|
||||
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
|
||||
// https://openai.com/blog/new-embedding-models-and-api-updates
|
||||
// Updated GPT-3.5 Turbo model and lower pricing
|
||||
if strings.HasSuffix(name, "0125") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasSuffix(name, "1106") {
|
||||
return 2
|
||||
}
|
||||
return 4.0 / 3.0
|
||||
if name == "gpt-3.5-turbo" {
|
||||
return 3
|
||||
}
|
||||
|
||||
return 1.333333
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") {
|
||||
if strings.HasSuffix(name, "preview") {
|
||||
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
@@ -219,6 +235,19 @@ func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gemini-") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "command") {
|
||||
switch name {
|
||||
case "command-r":
|
||||
return 3
|
||||
case "command-r-plus":
|
||||
return 5
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(name, "deepseek") {
|
||||
return 2
|
||||
}
|
||||
switch name {
|
||||
case "llama2-70b-4096":
|
||||
return 0.8 / 0.7
|
||||
|
||||
@@ -236,3 +236,8 @@ func StringToByteSlice(s string) []byte {
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
|
||||
8
constant/payment.go
Normal file
8
constant/payment.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
@@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
@@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
if testModel == "" {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
meta.UpstreamModelName = testModel
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
}
|
||||
}
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
@@ -83,7 +86,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
|
||||
@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
|
||||
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
239
controller/linuxdo.go
Normal file
239
controller/linuxdo.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LinuxDoOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type LinuxDoUser struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LinuxDoOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var linuxdoUser LinuxDoUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if linuxdoUser.ID == 0 {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
|
||||
return nil, errors.New("用户 LINUX DO 信任等级不足!")
|
||||
}
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxDoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
err := user.FillUserByLinuxDoId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
affCode := c.Query("aff")
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
|
||||
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if linuxdoUser.Name != "" {
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
} else {
|
||||
user.DisplayName = linuxdoUser.Username
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 LINUX DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -24,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -35,6 +35,7 @@ func GetAllLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"total": total,
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
@@ -58,7 +59,7 @@ func GetUserLogs(c *gin.Context) {
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -69,6 +70,7 @@ func GetUserLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"total": total,
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
|
||||
@@ -10,11 +10,11 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -86,7 +86,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
timeout := time.Second * 15
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
@@ -233,6 +233,12 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -259,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if !strings.Contains(common.ServerAddress, "localhost") {
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
|
||||
@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDoClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
@@ -46,7 +48,7 @@ func GetStatus(c *gin.Context) {
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": common.ServerAddress,
|
||||
"price": common.Price,
|
||||
"stripe_unit_price": common.StripeUnitPrice,
|
||||
"min_topup": common.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
@@ -60,7 +62,7 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "",
|
||||
"payment_enabled": common.PaymentEnabled,
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
},
|
||||
})
|
||||
@@ -120,12 +122,17 @@ func SendEmailVerification(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的邮箱地址",
|
||||
})
|
||||
return
|
||||
}
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
parts := strings.Split(email, "@")
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
|
||||
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if domainPart == domain {
|
||||
@@ -133,13 +140,7 @@ func SendEmailVerification(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if allowed && !containsSpecialSymbols {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Your email address is allowed.",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
|
||||
@@ -147,6 +148,17 @@ func SendEmailVerification(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if common.EmailAliasRestrictionEnabled {
|
||||
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
|
||||
if containsSpecialSymbols {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) {
|
||||
var options []*model.Option
|
||||
common.OptionMapRWMutex.Lock()
|
||||
for k, v := range common.OptionMap {
|
||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDoOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDoClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 LINUX DO OAuth,请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
@@ -42,7 +43,7 @@ func Relay(c *gin.Context) {
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
openaiErr := relayHandler(c, relayMode)
|
||||
retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
||||
useChannel := []int{channelId}
|
||||
if openaiErr != nil {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
} else {
|
||||
@@ -55,7 +56,7 @@ func Relay(c *gin.Context) {
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
||||
useChannel = append(useChannel, channelId)
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
@@ -66,7 +67,10 @@ func Relay(c *gin.Context) {
|
||||
go processChannelError(c, channelId, openaiErr)
|
||||
}
|
||||
}
|
||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
@@ -92,12 +96,23 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == 307 {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == 408 {
|
||||
// azure处理超时不重试
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
return false
|
||||
}
|
||||
@@ -109,7 +124,7 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
|
||||
channelName := c.GetString("channel_name")
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
@@ -145,7 +160,7 @@ func RelayMidjourney(c *gin.Context) {
|
||||
"code": err.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
97
controller/stripe.go
Normal file
97
controller/stripe.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v76"
|
||||
"github.com/stripe/stripe-go/v76/webhook"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := common.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if "" == referenceId {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
@@ -1,21 +1,20 @@
|
||||
package controller
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
epay "github.com/star-horizon/go-epay"
|
||||
"github.com/stripe/stripe-go/v76"
|
||||
"github.com/stripe/stripe-go/v76/checkout/session"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
type PayRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
@@ -26,177 +25,114 @@ type AmountRequest struct {
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if common.PayAddress == "" || common.EpayId == "" || common.EpayKey == "" {
|
||||
return nil
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
withUrl, err := epay.NewClientWithUrl(&epay.Config{
|
||||
PartnerID: common.EpayId,
|
||||
Key: common.EpayKey,
|
||||
}, common.PayAddress)
|
||||
|
||||
stripe.Key = common.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(common.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(common.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(common.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return nil
|
||||
return "", err
|
||||
}
|
||||
return withUrl
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func GetAmount(count float64, user model.User) float64 {
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
amount := count * common.Price * topupGroupRatio
|
||||
return amount
|
||||
func GetPayAmount(count float64) float64 {
|
||||
return count * common.StripeUnitPrice
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func RequestPayLink(c *gin.Context) {
|
||||
var req PayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
return
|
||||
}
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.PaymentMethod != "stripe" {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetAmount(float64(req.Amount), *user)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
var payType epay.PurchaseType
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = epay.Alipay
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = epay.WechatPay
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
ServiceTradeNo: "A" + tradeNo,
|
||||
Name: "B" + tradeNo,
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
Device: epay.PC,
|
||||
NotifyUrl: notifyUrl,
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
|
||||
referenceId := "ref_" + common.Sha1(reference)
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: payMoney,
|
||||
TradeNo: "A" + tradeNo,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
return r
|
||||
}, map[string]string{})
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
log.Println("易支付回调失败 未找到配置信息")
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*500000), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payLink": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestAmount(c *gin.Context) {
|
||||
@@ -206,12 +142,23 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetAmount(float64(req.Amount), *user)
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
payMoney := GetPayAmount(float64(req.Amount))
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -65,6 +66,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -215,7 +217,8 @@ func GetAllUsers(c *gin.Context) {
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
users, err := model.SearchUsers(keyword)
|
||||
group := c.Query("group")
|
||||
users, err := model.SearchUsers(keyword, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -451,7 +454,7 @@ func UpdateUser(c *gin.Context) {
|
||||
updatedUser.Password = "" // rollback to what it should be
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Update(updatePassword); err != nil {
|
||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
@@ -515,7 +518,7 @@ func UpdateSelf(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
func HardDeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -524,7 +527,7 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
originUser, err := model.GetUserByIdUnscoped(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -548,9 +551,23 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
if !common.UserSelfDeletionEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "当前设置不允许用户自我删除账号",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
@@ -724,7 +741,7 @@ func ManageUser(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
}
|
||||
|
||||
if err := user.UpdateAll(false); err != nil {
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
@@ -789,7 +806,11 @@ type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var lock = sync.Mutex{}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,18 +2,17 @@ version: '3.4'
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
# build: .
|
||||
image: pengzhile/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./data/new-api:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
@@ -23,13 +22,22 @@ services:
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
db:
|
||||
image: mysql:8.2.0
|
||||
container_name: mysql
|
||||
restart: always
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||
environment:
|
||||
TZ: Asia/Shanghai # 设置时区
|
||||
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||
MYSQL_USER: newapi # 创建专用用户
|
||||
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||
MYSQL_DATABASE: new-api # 自动创建数据库
|
||||
@@ -32,6 +32,21 @@ type GeneralOpenAIRequest struct {
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
Type string `json:"type"`
|
||||
Function OpenAIFunction `json:"function"`
|
||||
}
|
||||
|
||||
type OpenAIFunction struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||
return int64(r.MaxTokens)
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
|
||||
@@ -54,21 +54,54 @@ type OpenAIEmbeddingResponse struct {
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
||||
Logprobs *any `json:"logprobs"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoiceDelta struct {
|
||||
Content *string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
|
||||
return c.Content == nil && len(c.ToolCalls) == 0
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
||||
c.Content = &s
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
|
||||
if c.Content == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.Content
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
// Index is not nil only in chat completion chunk object
|
||||
Index *int `json:"index,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Type any `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
// call function with arguments in JSON format
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint *string `json:"system_fingerprint"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseSimple struct {
|
||||
|
||||
18
go.mod
18
go.mod
@@ -5,6 +5,9 @@ go 1.18
|
||||
|
||||
require (
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -15,10 +18,12 @@ require (
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.6
|
||||
github.com/samber/lo v1.38.1
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
|
||||
github.com/stripe/stripe-go/v76 v76.21.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/image v0.15.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
@@ -29,6 +34,10 @@ require (
|
||||
|
||||
require (
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
@@ -55,7 +64,6 @@ require (
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
@@ -65,9 +73,9 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
||||
41
go.sum
41
go.sum
@@ -2,6 +2,20 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
@@ -62,8 +76,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -83,6 +97,8 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
@@ -113,8 +129,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -128,6 +142,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -135,12 +151,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -153,6 +167,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
|
||||
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
@@ -173,18 +189,20 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -202,7 +220,6 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
|
||||
@@ -15,6 +15,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
linuxDoEnable := session.Get("linuxdo_enable")
|
||||
if username == nil {
|
||||
// Check access token
|
||||
accessToken := c.Request.Header.Get("Authorization")
|
||||
@@ -33,6 +34,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -50,6 +52,14 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户 LINUX DO 信任等级不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if role.(int) < minRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -112,6 +122,15 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !linuxDoEnabled {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
|
||||
@@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
shouldSelectChannel := true
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
if modelLimitEnable {
|
||||
@@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
@@ -147,14 +83,87 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
}
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
var err error
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return nil, false, fmt.Errorf(mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
@@ -163,12 +172,12 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if nil != channel.OpenAIOrganization {
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("original_model", modelName) // for retry
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
|
||||
@@ -3,6 +3,8 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
@@ -27,8 +29,7 @@ func GetGroupModels(group string) []string {
|
||||
return models
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
@@ -36,9 +37,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
if err != nil {
|
||||
// 处理错误
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 确定要使用的优先级
|
||||
var priorityToUse int
|
||||
if retry >= len(priorities) {
|
||||
// 如果重试次数大于优先级数,则使用最小的优先级
|
||||
priorityToUse = priorities[len(priorities)-1]
|
||||
} else {
|
||||
priorityToUse = priorities[retry]
|
||||
}
|
||||
return priorityToUse, nil
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
||||
}
|
||||
}
|
||||
|
||||
return channelQuery
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
var abilities []Ability
|
||||
|
||||
var err error = nil
|
||||
channelQuery := getChannelQuery(group, model, retry)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
@@ -57,7 +104,7 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight)
|
||||
weight -= int(ability_.Weight) + 10
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
@@ -88,7 +135,16 @@ func (channel *Channel) AddAbilities() error {
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
}
|
||||
return DB.Create(&abilities).Error
|
||||
if len(abilities) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) DeleteAbilities() error {
|
||||
|
||||
@@ -25,9 +25,6 @@ var token2UserId = make(map[string]int)
|
||||
var token2UserIdLock sync.RWMutex
|
||||
|
||||
func cacheSetToken(token *Token) error {
|
||||
if !common.RedisEnabled {
|
||||
return token.SelectUpdate()
|
||||
}
|
||||
jsonBytes, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -168,7 +165,11 @@ func CacheUpdateUserQuota(id int) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return cacheSetUserQuota(id, quota)
|
||||
}
|
||||
|
||||
func cacheSetUserQuota(id int, quota int) error {
|
||||
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -204,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsLinuxDoEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if linuxDoEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set linuxdo enabled error: " + err.Error())
|
||||
}
|
||||
return linuxDoEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
@@ -272,7 +297,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
return GetRandomSatisfiedChannel(group, model, retry)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
@@ -10,6 +10,7 @@ type Channel struct {
|
||||
Type int `json:"type" gorm:"default:0"`
|
||||
Key string `json:"key" gorm:"not null"`
|
||||
OpenAIOrganization *string `json:"openai_organization"`
|
||||
TestModel *string `json:"test_model"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index"`
|
||||
Weight *uint `json:"weight" gorm:"default:0"`
|
||||
@@ -24,8 +25,10 @@ type Channel struct {
|
||||
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
|
||||
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||
@@ -152,6 +155,13 @@ func (channel *Channel) GetModelMapping() string {
|
||||
return *channel.ModelMapping
|
||||
}
|
||||
|
||||
func (channel *Channel) GetStatusCodeMapping() string {
|
||||
if channel.StatusCodeMapping == nil {
|
||||
return ""
|
||||
}
|
||||
return *channel.StatusCodeMapping
|
||||
}
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(channel).Error
|
||||
|
||||
20
model/log.go
20
model/log.go
@@ -90,7 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = DB
|
||||
@@ -115,11 +115,17 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
return logs, err
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = DB.Where("user_id = ?", userId)
|
||||
@@ -138,8 +144,14 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
|
||||
return logs, err
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||
|
||||
@@ -31,10 +31,12 @@ func InitOptionMap() {
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
|
||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
@@ -44,6 +46,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
|
||||
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
|
||||
common.OptionMap["SMTPServer"] = ""
|
||||
common.OptionMap["SMTPFrom"] = ""
|
||||
@@ -58,15 +61,18 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64)
|
||||
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = common.StripePriceId
|
||||
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoClientId"] = ""
|
||||
common.OptionMap["LinuxDoClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
common.OptionMap["TelegramBotName"] = ""
|
||||
common.OptionMap["WeChatServerAddress"] = ""
|
||||
@@ -91,6 +97,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
@@ -164,6 +173,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
common.GitHubOAuthEnabled = boolValue
|
||||
case "LinuxDoOAuthEnabled":
|
||||
common.LinuxDoOAuthEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
common.WeChatAuthEnabled = boolValue
|
||||
case "TelegramOAuthEnabled":
|
||||
@@ -172,8 +183,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "UserSelfDeletionEnabled":
|
||||
common.UserSelfDeletionEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "EmailAliasRestrictionEnabled":
|
||||
common.EmailAliasRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
common.AutomaticDisableChannelEnabled = boolValue
|
||||
case "AutomaticEnableChannelEnabled":
|
||||
@@ -192,6 +207,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
case "MjAccountFilterEnabled":
|
||||
constant.MjAccountFilterEnabled = boolValue
|
||||
case "MjModeClearEnabled":
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
case "MjForwardUrlEnabled":
|
||||
constant.MjForwardUrlEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
@@ -220,16 +241,16 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
common.ServerAddress = value
|
||||
case "PayAddress":
|
||||
common.PayAddress = value
|
||||
case "CustomCallbackAddress":
|
||||
common.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
common.EpayId = value
|
||||
case "EpayKey":
|
||||
common.EpayKey = value
|
||||
case "Price":
|
||||
common.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "StripeApiSecret":
|
||||
common.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
common.StripeWebhookSecret = value
|
||||
case "StripePriceId":
|
||||
common.StripePriceId = value
|
||||
case "PaymentEnabled":
|
||||
common.PaymentEnabled, _ = strconv.ParseBool(value)
|
||||
case "StripeUnitPrice":
|
||||
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
common.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
@@ -238,6 +259,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.GitHubClientId = value
|
||||
case "GitHubClientSecret":
|
||||
common.GitHubClientSecret = value
|
||||
case "LinuxDoClientId":
|
||||
common.LinuxDoClientId = value
|
||||
case "LinuxDoClientSecret":
|
||||
common.LinuxDoClientSecret = value
|
||||
case "LinuxDoMinLevel":
|
||||
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
|
||||
case "Footer":
|
||||
common.Footer = value
|
||||
case "SystemName":
|
||||
|
||||
@@ -56,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
common.RandomSleep()
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -102,6 +102,11 @@ func GetTokenById(id int) (*Token, error) {
|
||||
token := Token{Id: id}
|
||||
var err error = nil
|
||||
err = DB.First(&token, "id = ?", id).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetToken(&token)
|
||||
}
|
||||
}
|
||||
return &token, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
}
|
||||
return topUp
|
||||
}
|
||||
|
||||
func Recharge(referenceId string, customerId string) (err error) {
|
||||
if referenceId == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quota float64
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err = tx.Save(topUp).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
quota = topUp.Money * common.QuotaPerUnit
|
||||
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.New("充值失败," + err.Error())
|
||||
}
|
||||
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.LogQuotaF(quota), topUp.Amount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
|
||||
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -35,6 +37,7 @@ type User struct {
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
@@ -64,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
|
||||
func GetMaxUserId() int {
|
||||
var user User
|
||||
DB.Last(&user)
|
||||
DB.Unscoped().Last(&user)
|
||||
return user.Id
|
||||
}
|
||||
|
||||
@@ -73,25 +76,34 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
return users, err
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string) ([]*User, error) {
|
||||
func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID搜索用户
|
||||
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
|
||||
// 如果转换成功,按照ID和可选的组别搜索用户
|
||||
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
|
||||
if group != "" {
|
||||
query = query.Where("`group` = ?", group) // 使用反引号包围group
|
||||
}
|
||||
err = query.Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
// 如果依据ID找到用户或者发生错误,返回结果或错误
|
||||
return users, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索
|
||||
err = DB.Unscoped().Omit("password").
|
||||
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
|
||||
Find(&users).Error
|
||||
err = nil
|
||||
|
||||
query := DB.Unscoped().Omit("password")
|
||||
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
err = query.Find(&users).Error
|
||||
|
||||
return users, err
|
||||
}
|
||||
@@ -110,6 +122,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
user := User{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.Unscoped().First(&user, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserIdByAffCode(affCode string) (int, error) {
|
||||
if affCode == "" {
|
||||
return 0, errors.New("affCode 为空!")
|
||||
@@ -235,7 +261,7 @@ func (user *User) Update(updatePassword bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) UpdateAll(updatePassword bool) error {
|
||||
func (user *User) Edit(updatePassword bool) error {
|
||||
var err error
|
||||
if updatePassword {
|
||||
user.Password, err = common.Password2Hash(user.Password)
|
||||
@@ -244,8 +270,17 @@ func (user *User) UpdateAll(updatePassword bool) error {
|
||||
}
|
||||
}
|
||||
newUser := *user
|
||||
updates := map[string]interface{}{
|
||||
"username": newUser.Username,
|
||||
"display_name": newUser.DisplayName,
|
||||
"group": newUser.Group,
|
||||
"quota": newUser.Quota,
|
||||
}
|
||||
if updatePassword {
|
||||
updates["password"] = newUser.Password
|
||||
}
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Select("*").Updates(newUser).Error
|
||||
err = DB.Model(user).Updates(updates).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
@@ -312,6 +347,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByLinuxDoId() error {
|
||||
if user.LinuxDoId == "" {
|
||||
return errors.New("LINUX DO id 为空!")
|
||||
}
|
||||
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -351,6 +394,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
|
||||
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -396,6 +443,18 @@ func IsUserEnabled(userId int) (bool, error) {
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func IsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
@@ -410,6 +469,11 @@ func ValidateAccessToken(token string) (user *User) {
|
||||
|
||||
func GetUserQuota(id int) (quota int, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetUserQuota(id, quota)
|
||||
}
|
||||
}
|
||||
return quota, err
|
||||
}
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
choice.Delta.SetContentString(aliResponse.Output.Text)
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
@@ -199,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
|
||||
79
relay/channel/aws/adaptor.go
Normal file
79
relay/channel/aws/adaptor.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||
a.RequestMode = RequestModeMessage
|
||||
} else {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
var claudeReq *claude.ClaudeRequest
|
||||
var err error
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
|
||||
} else {
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a.RequestMode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() (models []string) {
|
||||
for n := range awsModelIDMap {
|
||||
models = append(models, n)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
12
relay/channel/aws/constants.go
Normal file
12
relay/channel/aws/constants.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package aws
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
15
relay/channel/aws/dto.go
Normal file
15
relay/channel/aws/dto.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package aws
|
||||
|
||||
import "one-api/relay/channel/claude"
|
||||
|
||||
type AwsClaudeRequest struct {
|
||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
System string `json:"system"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
}
|
||||
213
relay/channel/aws/relay-aws.go
Normal file
213
relay/channel/aws/relay-aws.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
relaymodel "one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 3 {
|
||||
return nil, errors.New("invalid aws secret key")
|
||||
}
|
||||
ak := awsSecret[0]
|
||||
sk := awsSecret[1]
|
||||
region := awsSecret[2]
|
||||
client := bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
|
||||
return &relaymodel.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: relaymodel.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID, nil
|
||||
}
|
||||
|
||||
return "", errors.Errorf("model %s not found", requestModel)
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
awsClaudeReq := &AwsClaudeRequest{
|
||||
AnthropicVersion: "bedrock-2023-05-31",
|
||||
}
|
||||
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||
}
|
||||
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
}
|
||||
|
||||
claudeResponse := new(claude.ClaudeResponse)
|
||||
err = json.Unmarshal(awsResp.Body, claudeResponse)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
|
||||
}
|
||||
|
||||
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
|
||||
usage := relaymodel.Usage{
|
||||
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||
}
|
||||
openaiResp.Usage = usage
|
||||
|
||||
c.JSON(http.StatusOK, openaiResp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := &AwsClaudeRequest{
|
||||
AnthropicVersion: "bedrock-2023-05-31",
|
||||
}
|
||||
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||
}
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
var usage relaymodel.Usage
|
||||
var id string
|
||||
var model string
|
||||
createdTime := common.GetTimestamp()
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
|
||||
switch v := event.(type) {
|
||||
case *types.ResponseStreamMemberChunk:
|
||||
claudeResp := new(claude.ClaudeResponse)
|
||||
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
|
||||
if claudeUsage != nil {
|
||||
usage.PromptTokens += claudeUsage.InputTokens
|
||||
usage.CompletionTokens += claudeUsage.OutputTokens
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if response.Id != "" {
|
||||
id = response.Id
|
||||
}
|
||||
if response.Model != "" {
|
||||
model = response.Model
|
||||
}
|
||||
response.Created = createdTime
|
||||
response.Id = id
|
||||
response.Model = model
|
||||
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case *types.UnknownUnionMember:
|
||||
fmt.Println("unknown tag:", v.Tag)
|
||||
return false
|
||||
default:
|
||||
fmt.Println("union is nil or unknown type")
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
|
||||
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
choice.Delta.SetContentString(baiduResponse.Result)
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return requestOpenAI2ClaudeComplete(*request), nil
|
||||
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return requestOpenAI2ClaudeMessage(*request)
|
||||
return RequestOpenAI2ClaudeMessage(*request)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ type ClaudeRequest struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
|
||||
@@ -20,25 +20,24 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "length"
|
||||
return "max_tokens"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: textRequest.MaxTokens,
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 1000000
|
||||
claudeRequest.MaxTokensToSample = 4096
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
@@ -57,7 +56,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||
claudeRequest := ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
@@ -70,10 +69,52 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
}
|
||||
formatMessages := make([]dto.Message, 0)
|
||||
var lastMessage *dto.Message
|
||||
for i, message := range textRequest.Messages {
|
||||
//if message.Role == "system" {
|
||||
// if i != 0 {
|
||||
// message.Role = "user"
|
||||
// }
|
||||
//}
|
||||
if message.Role == "" {
|
||||
textRequest.Messages[i].Role = "user"
|
||||
}
|
||||
fmtMessage := dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
if lastMessage != nil && lastMessage.Role == message.Role {
|
||||
if lastMessage.IsStringContent() && message.IsStringContent() {
|
||||
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||||
fmtMessage.Content = content
|
||||
// delete last message
|
||||
formatMessages = formatMessages[:len(formatMessages)-1]
|
||||
}
|
||||
}
|
||||
if fmtMessage.Content == nil {
|
||||
content, _ := json.Marshal("...")
|
||||
fmtMessage.Content = content
|
||||
}
|
||||
formatMessages = append(formatMessages, fmtMessage)
|
||||
lastMessage = &textRequest.Messages[i]
|
||||
}
|
||||
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
for _, message := range textRequest.Messages {
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
claudeRequest.System = message.StringContent()
|
||||
if message.IsStringContent() {
|
||||
claudeRequest.System = message.StringContent()
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
content := ""
|
||||
for _, ctx := range contents {
|
||||
if ctx.Type == "text" {
|
||||
content += ctx.Text
|
||||
}
|
||||
}
|
||||
claudeRequest.System = content
|
||||
}
|
||||
} else {
|
||||
claudeMessage := ClaudeMessage{
|
||||
Role: message.Role,
|
||||
@@ -118,11 +159,10 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
var claudeUsage *ClaudeUsage
|
||||
response.Object = "chat.completion.chunk"
|
||||
@@ -130,7 +170,7 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
choice.Delta.SetContentString(claudeResponse.Completion)
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
@@ -140,25 +180,34 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
claudeUsage = &claudeResponse.Message.Usage
|
||||
choice.Delta.SetContentString("")
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
return nil, nil
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
choice.Index = claudeResponse.Index
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
choice.Delta.SetContentString(claudeResponse.Delta.Text)
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
return nil, nil
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
if claudeUsage == nil {
|
||||
claudeUsage = &ClaudeUsage{}
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
|
||||
return &response, claudeUsage
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -242,7 +291,10 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
||||
return true
|
||||
}
|
||||
|
||||
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
if response == nil {
|
||||
return true
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
responseText += claudeResponse.Completion
|
||||
responseId = response.Id
|
||||
@@ -317,7 +369,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
|
||||
52
relay/channel/cohere/adaptor.go
Normal file
52
relay/channel/cohere/adaptor.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
7
relay/channel/cohere/constant.go
Normal file
7
relay/channel/cohere/constant.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package cohere
|
||||
|
||||
var ModelList = []string{
|
||||
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
||||
}
|
||||
|
||||
var ChannelName = "cohere"
|
||||
44
relay/channel/cohere/dto.go
Normal file
44
relay/channel/cohere/dto.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package cohere
|
||||
|
||||
type CohereRequest struct {
|
||||
Model string `json:"model"`
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type ChatHistory struct {
|
||||
Role string `json:"role"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type CohereResponse struct {
|
||||
IsFinished bool `json:"is_finished"`
|
||||
EventType string `json:"event_type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Response *CohereResponseResult `json:"response"`
|
||||
}
|
||||
|
||||
type CohereResponseResult struct {
|
||||
ResponseId string `json:"response_id"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Text string `json:"text"`
|
||||
Meta CohereMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type CohereMeta struct {
|
||||
//Tokens CohereTokens `json:"tokens"`
|
||||
BilledUnits CohereBilledUnits `json:"billed_units"`
|
||||
}
|
||||
|
||||
type CohereBilledUnits struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type CohereTokens struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
189
relay/channel/cohere/relay-cohere.go
Normal file
189
relay/channel/cohere/relay-cohere.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
cohereReq := CohereRequest{
|
||||
Model: textRequest.Model,
|
||||
ChatHistory: []ChatHistory{},
|
||||
Message: "",
|
||||
Stream: textRequest.Stream,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
}
|
||||
if cohereReq.MaxTokens == 0 {
|
||||
cohereReq.MaxTokens = 4000
|
||||
}
|
||||
for _, msg := range textRequest.Messages {
|
||||
if msg.Role == "user" {
|
||||
cohereReq.Message = msg.StringContent()
|
||||
} else {
|
||||
var role string
|
||||
if msg.Role == "assistant" {
|
||||
role = "CHATBOT"
|
||||
} else if msg.Role == "system" {
|
||||
role = "SYSTEM"
|
||||
} else {
|
||||
role = "USER"
|
||||
}
|
||||
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
|
||||
Role: role,
|
||||
Message: msg.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &cohereReq
|
||||
}
|
||||
|
||||
func stopReasonCohere2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "COMPLETE":
|
||||
return "stop"
|
||||
case "MAX_TOKENS":
|
||||
return "max_tokens"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var cohereResp CohereResponse
|
||||
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
var openaiResp dto.ChatCompletionsStreamResponse
|
||||
openaiResp.Id = responseId
|
||||
openaiResp.Created = createdTime
|
||||
openaiResp.Object = "chat.completion.chunk"
|
||||
openaiResp.Model = modelName
|
||||
if cohereResp.IsFinished {
|
||||
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
||||
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
},
|
||||
}
|
||||
if cohereResp.Response != nil {
|
||||
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
|
||||
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
|
||||
}
|
||||
} else {
|
||||
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
Content: &cohereResp.Text,
|
||||
},
|
||||
Index: 0,
|
||||
},
|
||||
}
|
||||
responseText += cohereResp.Text
|
||||
}
|
||||
jsonStr, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
createdTime := common.GetTimestamp()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
||||
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
||||
|
||||
var openaiResp dto.TextResponse
|
||||
openaiResp.Id = cohereResp.ResponseId
|
||||
openaiResp.Created = createdTime
|
||||
openaiResp.Object = "chat.completion"
|
||||
openaiResp.Model = modelName
|
||||
openaiResp.Usage = usage
|
||||
|
||||
content, _ := json.Marshal(cohereResp.Text)
|
||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: dto.Message{Content: content, Role: "assistant"},
|
||||
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
||||
},
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -18,16 +18,28 @@ type Adaptor struct {
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
// 定义一个映射,存储模型名称和对应的版本
|
||||
var modelVersionMap = map[string]string{
|
||||
"gemini-1.5-pro-latest": "v1beta",
|
||||
"gemini-ultra": "v1beta",
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
version := "v1"
|
||||
if info.ApiVersion != "" {
|
||||
version = info.ApiVersion
|
||||
}
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
||||
if !beta {
|
||||
if info.ApiVersion != "" {
|
||||
version = info.ApiVersion
|
||||
} else {
|
||||
version = "v1"
|
||||
}
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -5,8 +5,8 @@ const (
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
|
||||
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
|
||||
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-ultra",
|
||||
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -151,7 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
@@ -203,7 +203,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
choice.Delta.SetContentString(dummy.Content)
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
|
||||
@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
@@ -41,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: request.Input,
|
||||
Prompt: strings.Join(request.ParseInput(), " "),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
var toolCount int
|
||||
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ var ModelList = []string{
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
|
||||
@@ -16,9 +16,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||
var responseTextBuilder strings.Builder
|
||||
toolCount := 0
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -49,7 +50,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
common.SafeSendString(dataChan, data)
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
streamItems = append(streamItems, data)
|
||||
@@ -67,14 +68,32 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||
if err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, streamResponse := range streamResponses {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
}
|
||||
for _, tool := range choice.Delta.ToolCalls {
|
||||
responseTextBuilder.WriteString(tool.Function.Name)
|
||||
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,7 +123,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
// wait data out
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
common.SafeSend(stopChan, true)
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
service.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
@@ -123,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
|
||||
}
|
||||
wg.Wait()
|
||||
return nil, responseTextBuilder.String()
|
||||
return nil, responseTextBuilder.String(), toolCount
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
|
||||
@@ -61,7 +61,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
|
||||
}
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
|
||||
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
|
||||
@@ -86,7 +86,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
}
|
||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.Content
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
|
||||
@@ -87,7 +87,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo
|
||||
}
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
}
|
||||
@@ -179,7 +179,13 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
|
||||
@@ -126,7 +126,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
choice.Delta.SetContentString(zhipuResponse)
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
@@ -138,7 +138,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream
|
||||
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.Delta.SetContentString("")
|
||||
choice.FinishReason = &relaycommon.StopFinishReason
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.RequestId,
|
||||
|
||||
@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
var toolCount int
|
||||
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ const (
|
||||
APITypeZhipu_v4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -49,6 +51,10 @@ func ChannelType2APIType(channelType int) int {
|
||||
apiType = APITypeOllama
|
||||
case common.ChannelTypePerplexity:
|
||||
apiType = APITypePerplexity
|
||||
case common.ChannelTypeAws:
|
||||
apiType = APITypeAws
|
||||
case common.ChannelTypeCohere:
|
||||
apiType = APITypeCohere
|
||||
}
|
||||
return apiType
|
||||
}
|
||||
|
||||
@@ -20,15 +20,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var availableVoices = []string{
|
||||
"alloy",
|
||||
"echo",
|
||||
"fable",
|
||||
"onyx",
|
||||
"nova",
|
||||
"shimmer",
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
@@ -59,9 +50,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
if audioRequest.Voice == "" {
|
||||
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
if !common.StringsContains(availableVoices, audioRequest.Voice) {
|
||||
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
var err error
|
||||
promptTokens := 0
|
||||
@@ -100,6 +88,22 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
}
|
||||
|
||||
succeed := false
|
||||
defer func() {
|
||||
if succeed {
|
||||
return
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
// we need to roll back the pre-consumed quota
|
||||
defer func() {
|
||||
go func() {
|
||||
// negative means add quota back for token & user
|
||||
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
@@ -163,6 +167,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return relaycommon.RelayErrorHandler(resp)
|
||||
}
|
||||
succeed = true
|
||||
|
||||
var audioResponse dto.AudioResponse
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
imageRequest.Model = "dall-e-3"
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
@@ -186,7 +186,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
quality := "normal"
|
||||
if imageRequest.Quality == "hd" {
|
||||
quality = "hd"
|
||||
}
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
|
||||
@@ -110,11 +110,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" {
|
||||
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
} else {
|
||||
midjourneyTask.ImageUrl = originTask.ImageUrl
|
||||
}
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
|
||||
@@ -154,20 +154,28 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
return service.RelayErrorHandler(resp)
|
||||
if resp != nil {
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
||||
@@ -181,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
checkSensitive := constant.ShouldCheckPromptSensitive()
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
|
||||
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
||||
case relayconstant.RelayModeModerations:
|
||||
@@ -256,10 +264,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
completionTokens := usage.CompletionTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
|
||||
quota := 0
|
||||
if modelPrice == -1 {
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
@@ -271,7 +279,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
totalTokens := promptTokens + completionTokens
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
@@ -288,11 +296,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package relay
|
||||
import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/aws"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/cohere"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
@@ -45,6 +47,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &ollama.Adaptor{}
|
||||
case constant.APITypePerplexity:
|
||||
return &perplexity.Adaptor{}
|
||||
case constant.APITypeAws:
|
||||
return &aws.Adaptor{}
|
||||
case constant.APITypeCohere:
|
||||
return &cohere.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,12 +17,13 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
@@ -30,13 +31,14 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
|
||||
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
userRoute.GET("/epay/notify", controller.EpayNotify)
|
||||
|
||||
selfRoute := userRoute.Group("/")
|
||||
selfRoute.Use(middleware.UserAuth())
|
||||
@@ -47,8 +49,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.POST("/topup", controller.TopUp)
|
||||
selfRoute.POST("/pay", controller.RequestEpay)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
}
|
||||
@@ -62,7 +64,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.POST("/", controller.CreateUser)
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
adminRoute.PUT("/", controller.UpdateUser)
|
||||
adminRoute.DELETE("/:id", controller.DeleteUser)
|
||||
adminRoute.DELETE("/:id", controller.HardDeleteUser)
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
|
||||
@@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package service
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if common.CustomCallbackAddress == "" {
|
||||
return common.ServerAddress
|
||||
}
|
||||
return common.CustomCallbackAddress
|
||||
}
|
||||
@@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
|
||||
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
|
||||
return
|
||||
}
|
||||
statusCodeMapping := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusOK {
|
||||
return
|
||||
}
|
||||
codeStr := strconv.Itoa(openaiErr.StatusCode)
|
||||
if _, ok := statusCodeMapping[codeStr]; ok {
|
||||
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
|
||||
openaiErr.StatusCode = intCode
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,13 +165,24 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
delete(mapResult, "accountFilter")
|
||||
if !constant.MjAccountFilterEnabled {
|
||||
delete(mapResult, "accountFilter")
|
||||
}
|
||||
if !constant.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// make new request with mapResult
|
||||
}
|
||||
if constant.MjModeClearEnabled {
|
||||
if prompt, ok := mapResult["prompt"].(string); ok {
|
||||
prompt = strings.Replace(prompt, "--fast", "", -1)
|
||||
prompt = strings.Replace(prompt, "--relax", "", -1)
|
||||
prompt = strings.Replace(prompt, "--turbo", "", -1)
|
||||
|
||||
mapResult["prompt"] = prompt
|
||||
}
|
||||
}
|
||||
reqBody, err := json.Marshal(mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
|
||||
@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
||||
return tiles*170 + 85, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
|
||||
tkm := 0
|
||||
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
|
||||
if err != nil {
|
||||
return 0, err, b
|
||||
}
|
||||
tkm += msgTokens
|
||||
if request.Tools != nil {
|
||||
toolsData, _ := json.Marshal(request.Tools)
|
||||
var openaiTools []dto.OpenAITools
|
||||
err := json.Unmarshal(toolsData, &openaiTools)
|
||||
if err != nil {
|
||||
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
|
||||
}
|
||||
countStr := ""
|
||||
for _, tool := range openaiTools {
|
||||
countStr = tool.Function.Name
|
||||
if tool.Function.Description != "" {
|
||||
countStr += tool.Function.Description
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err, _ := CountTokenInput(countStr, model, false)
|
||||
if err != nil {
|
||||
return 0, err, false
|
||||
}
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
}
|
||||
|
||||
return tkm, nil, false
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
@@ -138,48 +173,31 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if len(message.Content) > 0 {
|
||||
var arrayContent []dto.MediaMessage
|
||||
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
||||
return 0, err, false
|
||||
} else {
|
||||
if checkSensitive {
|
||||
contains, words := SensitiveWordContains(stringContent)
|
||||
if contains {
|
||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||
return 0, err, true
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
if message.IsStringContent() {
|
||||
stringContent := message.StringContent()
|
||||
if checkSensitive {
|
||||
contains, words := SensitiveWordContains(stringContent)
|
||||
if contains {
|
||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||
return 0, err, true
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == "image_url" {
|
||||
var imageTokenNum int
|
||||
if model == "glm-4v" {
|
||||
imageTokenNum = 1047
|
||||
} else {
|
||||
if str, ok := m.ImageUrl.(string); ok {
|
||||
imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
|
||||
} else {
|
||||
imageUrlMap := m.ImageUrl.(map[string]interface{})
|
||||
detail, ok := imageUrlMap["detail"]
|
||||
if ok {
|
||||
imageUrlMap["detail"] = detail.(string)
|
||||
} else {
|
||||
imageUrlMap["detail"] = "auto"
|
||||
}
|
||||
imageUrl := dto.MessageImageUrl{
|
||||
Url: imageUrlMap["url"].(string),
|
||||
Detail: imageUrlMap["detail"].(string),
|
||||
}
|
||||
imageTokenNum, err = getImageToken(&imageUrl)
|
||||
}
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err = getImageToken(&imageUrl)
|
||||
if err != nil {
|
||||
return 0, err, false
|
||||
}
|
||||
@@ -211,6 +229,23 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
||||
}
|
||||
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
|
||||
tokens += tkm
|
||||
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
contains, words := SensitiveWordContains(text)
|
||||
|
||||
5
web/.gitignore
vendored
5
web/.gitignore
vendored
@@ -10,6 +10,7 @@
|
||||
|
||||
# production
|
||||
/build
|
||||
/dist
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
@@ -21,6 +22,4 @@
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.idea
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.idea/
|
||||
@@ -1 +1 @@
|
||||
module.exports = require("@so1ve/prettier-config");
|
||||
module.exports = require('@so1ve/prettier-config');
|
||||
|
||||
@@ -11,6 +11,7 @@ import EditUser from './pages/User/EditUser';
|
||||
import { getLogo, getSystemName } from './helpers';
|
||||
import PasswordResetForm from './components/PasswordResetForm';
|
||||
import GitHubOAuth from './components/GitHubOAuth';
|
||||
import LinuxDoOAuth from './components/LinuxDoOAuth';
|
||||
import PasswordResetConfirm from './components/PasswordResetConfirm';
|
||||
import { UserContext } from './context/User';
|
||||
import Channel from './pages/Channel';
|
||||
@@ -171,6 +172,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/linuxdo'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LinuxDoOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/setting'
|
||||
element={
|
||||
|
||||
@@ -254,6 +254,19 @@ const ChannelsTable = () => {
|
||||
>
|
||||
编辑
|
||||
</Button>
|
||||
<Popconfirm
|
||||
title='确定是否要复制此渠道?'
|
||||
content='复制渠道的所有信息'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={async () => {
|
||||
copySelectedChannel(record.id);
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='primary' style={{ marginRight: 1 }}>
|
||||
复制
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
@@ -340,6 +353,33 @@ const ChannelsTable = () => {
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const copySelectedChannel = async (id) => {
|
||||
const channelToCopy = channels.find(
|
||||
(channel) => String(channel.id) === String(id),
|
||||
);
|
||||
console.log(channelToCopy);
|
||||
channelToCopy.name += '_复制';
|
||||
channelToCopy.created_time = null;
|
||||
channelToCopy.balance = 0;
|
||||
channelToCopy.used_quota = 0;
|
||||
if (!channelToCopy) {
|
||||
showError('渠道未找到,请刷新页面后重试。');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const newChannel = { ...channelToCopy, id: undefined };
|
||||
const response = await API.post('/api/channel/', newChannel);
|
||||
if (response.data.success) {
|
||||
showSuccess('渠道复制成功');
|
||||
await refresh();
|
||||
} else {
|
||||
showError(response.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError('渠道复制失败: ' + error.message);
|
||||
}
|
||||
};
|
||||
|
||||
const refresh = async () => {
|
||||
await loadChannels(activePage - 1, pageSize, idSort);
|
||||
};
|
||||
|
||||
@@ -14,9 +14,14 @@ const GitHubOAuth = () => {
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(
|
||||
`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
localStorage.removeItem('aff');
|
||||
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/setting');
|
||||
@@ -41,6 +46,14 @@ const GitHubOAuth = () => {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Link, useNavigate } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { useSetTheme, useTheme } from '../context/Theme';
|
||||
|
||||
import { API, getLogo, getSystemName, showSuccess } from '../helpers';
|
||||
import '../index.css';
|
||||
@@ -34,10 +35,8 @@ const HeaderBar = () => {
|
||||
let navigate = useNavigate();
|
||||
|
||||
const [showSidebar, setShowSidebar] = useState(false);
|
||||
const [dark, setDark] = useState(false);
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
var themeMode = localStorage.getItem('theme-mode');
|
||||
const currentDate = new Date();
|
||||
// enable fireworks on new year(1.1 and 2.9-2.24)
|
||||
const isNewYear =
|
||||
@@ -66,26 +65,19 @@ const HeaderBar = () => {
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
const theme = useTheme();
|
||||
const setTheme = useSetTheme();
|
||||
|
||||
useEffect(() => {
|
||||
if (themeMode === 'dark') {
|
||||
switchMode(true);
|
||||
if (theme === 'dark') {
|
||||
document.body.setAttribute('theme-mode', 'dark');
|
||||
}
|
||||
|
||||
if (isNewYear) {
|
||||
console.log('Happy New Year!');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const switchMode = (model) => {
|
||||
const body = document.body;
|
||||
if (!model) {
|
||||
body.removeAttribute('theme-mode');
|
||||
localStorage.setItem('theme-mode', 'light');
|
||||
} else {
|
||||
body.setAttribute('theme-mode', 'dark');
|
||||
localStorage.setItem('theme-mode', 'dark');
|
||||
}
|
||||
setDark(model);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<Layout>
|
||||
@@ -132,9 +124,11 @@ const HeaderBar = () => {
|
||||
<Switch
|
||||
checkedText='🌞'
|
||||
size={'large'}
|
||||
checked={dark}
|
||||
checked={theme === 'dark'}
|
||||
uncheckedText='🌙'
|
||||
onChange={switchMode}
|
||||
onChange={(checked) => {
|
||||
setTheme(checked);
|
||||
}}
|
||||
/>
|
||||
{userState.user ? (
|
||||
<>
|
||||
|
||||
27
web/src/components/LinuxDoIcon.js
Normal file
27
web/src/components/LinuxDoIcon.js
Normal file
@@ -0,0 +1,27 @@
|
||||
import React from 'react';
|
||||
import { Icon } from '@douyinfe/semi-ui';
|
||||
|
||||
const LinuxDoIcon = (props) => {
|
||||
function CustomIcon() {
|
||||
return (
|
||||
<svg
|
||||
className='icon'
|
||||
viewBox='0 0 24 24'
|
||||
version='1.1'
|
||||
xmlns='http://www.w3.org/2000/svg'
|
||||
width='1em'
|
||||
height='1em'
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
|
||||
fill='currentColor'
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
return <Icon svg={<CustomIcon />} />;
|
||||
};
|
||||
|
||||
export default LinuxDoIcon;
|
||||
71
web/src/components/LinuxDoOAuth.js
Normal file
71
web/src/components/LinuxDoOAuth.js
Normal file
@@ -0,0 +1,71 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import { UserContext } from '../context/User';
|
||||
|
||||
const LinuxDoOAuth = () => {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [prompt, setPrompt] = useState('处理中...');
|
||||
const [processing, setProcessing] = useState(true);
|
||||
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(
|
||||
`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
localStorage.removeItem('aff');
|
||||
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/setting');
|
||||
} else {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
if (count === 0) {
|
||||
setPrompt(`操作失败,重定向至登录界面中...`);
|
||||
navigate('/setting'); // in case this is failed to bind GitHub
|
||||
return;
|
||||
}
|
||||
count++;
|
||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
||||
await sendCode(code, state, count);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Segment style={{ minHeight: '300px' }}>
|
||||
<Dimmer active inverted>
|
||||
<Loader size='large'>{prompt}</Loader>
|
||||
</Dimmer>
|
||||
</Segment>
|
||||
);
|
||||
};
|
||||
|
||||
export default LinuxDoOAuth;
|
||||
@@ -2,7 +2,7 @@ import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
||||
import { onGitHubOAuthClicked } from './utils';
|
||||
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import {
|
||||
Button,
|
||||
@@ -18,6 +18,7 @@ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
|
||||
import { IconGithubLogo } from '@douyinfe/semi-icons';
|
||||
import LinuxDoIcon from './LinuxDoIcon';
|
||||
import WeChatIcon from './WeChatIcon';
|
||||
|
||||
const LoginForm = () => {
|
||||
@@ -231,10 +232,25 @@ const LoginForm = () => {
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.linuxdo_oauth ? (
|
||||
<Button
|
||||
type='primary'
|
||||
icon={<LinuxDoIcon />}
|
||||
style={{ color: '#000', margin: '0 5px' }}
|
||||
onClick={() =>
|
||||
onLinuxDoOAuthClicked(status.linuxdo_client_id)
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.wechat_login ? (
|
||||
<Button
|
||||
type='primary'
|
||||
style={{ color: 'rgba(var(--semi-green-5), 1)' }}
|
||||
style={{
|
||||
color: 'rgba(var(--semi-green-5), 1)',
|
||||
margin: '0 5px',
|
||||
}}
|
||||
icon={<Icon svg={<WeChatIcon />} />}
|
||||
onClick={onWeChatLoginClicked}
|
||||
/>
|
||||
@@ -244,6 +260,8 @@ const LoginForm = () => {
|
||||
|
||||
{status.telegram_oauth ? (
|
||||
<TelegramLoginButton
|
||||
className='semi-button semi-button-with-icon semi-button-with-icon-only'
|
||||
buttonSize='medium'
|
||||
dataOnauth={onTelegramLoginClicked}
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
|
||||
@@ -312,7 +312,7 @@ const LogsTable = () => {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [loadingStat, setLoadingStat] = useState(false);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [logCount, setLogCount] = useState(0);
|
||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
@@ -409,14 +409,14 @@ const LogsTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
const setLogsFormat = (logs, total) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
setLogCount(total);
|
||||
// console.log(logCount);
|
||||
};
|
||||
|
||||
@@ -432,14 +432,14 @@ const LogsTable = () => {
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
const { success, message, total, data } = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
setLogsFormat(data, total);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * pageSize, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
setLogsFormat(newLogs, total);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -606,7 +606,9 @@ const LogsTable = () => {
|
||||
type='primary'
|
||||
htmlType='submit'
|
||||
className='btn-margin-right'
|
||||
onClick={refresh}
|
||||
onClick={() => {
|
||||
refresh(logType).then();
|
||||
}}
|
||||
loading={loading}
|
||||
>
|
||||
查询
|
||||
|
||||
@@ -236,6 +236,31 @@ const renderTimestamp = (timestampInSeconds) => {
|
||||
|
||||
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
|
||||
};
|
||||
// 修改renderDuration函数以包含颜色逻辑
|
||||
function renderDuration(submit_time, finishTime) {
|
||||
// 确保startTime和finishTime都是有效的时间戳
|
||||
if (!submit_time || !finishTime) return 'N/A';
|
||||
|
||||
// 将时间戳转换为Date对象
|
||||
const start = new Date(submit_time);
|
||||
const finish = new Date(finishTime);
|
||||
|
||||
// 计算时间差(毫秒)
|
||||
const durationMs = finish - start;
|
||||
|
||||
// 将时间差转换为秒,并保留一位小数
|
||||
const durationSec = (durationMs / 1000).toFixed(1);
|
||||
|
||||
// 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
|
||||
const color = durationSec > 60 ? 'red' : 'green';
|
||||
|
||||
// 返回带有样式的颜色标签
|
||||
return (
|
||||
<Tag color={color} size='large'>
|
||||
{durationSec} 秒
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const LogsTable = () => {
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
@@ -248,6 +273,15 @@ const LogsTable = () => {
|
||||
return <div>{renderTimestamp(text / 1000)}</div>;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '花费时间',
|
||||
dataIndex: 'finish_time', // 以finish_time作为dataIndex
|
||||
key: 'finish_time',
|
||||
render: (finish, record) => {
|
||||
// 假设record.start_time是存在的,并且finish是完成时间的时间戳
|
||||
return renderDuration(record.submit_time, finish);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '渠道',
|
||||
dataIndex: 'channel_id',
|
||||
|
||||
@@ -8,6 +8,8 @@ import {
|
||||
verifyJSON,
|
||||
} from '../helpers';
|
||||
|
||||
import { useTheme } from '../context/Theme';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
@@ -36,6 +38,9 @@ const OperationSetting = () => {
|
||||
StopOnSensitiveEnabled: '',
|
||||
SensitiveWords: '',
|
||||
MjNotifyEnabled: '',
|
||||
MjAccountFilterEnabled: '',
|
||||
MjModeClearEnabled: '',
|
||||
MjForwardUrlEnabled: '',
|
||||
DrawingEnabled: '',
|
||||
DataExportEnabled: '',
|
||||
DataExportDefaultTime: 'hour',
|
||||
@@ -76,6 +81,9 @@ const OperationSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const theme = useTheme();
|
||||
const isDark = theme === 'dark';
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
}, []);
|
||||
@@ -218,8 +226,10 @@ const OperationSetting = () => {
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
<Header as='h3'>通用设置</Header>
|
||||
<Form loading={loading} inverted={isDark}>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
通用设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label='充值链接'
|
||||
@@ -298,7 +308,9 @@ const OperationSetting = () => {
|
||||
保存通用设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>绘图设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
绘图设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DrawingEnabled === 'true'}
|
||||
@@ -312,9 +324,29 @@ const OperationSetting = () => {
|
||||
name='MjNotifyEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.MjAccountFilterEnabled === 'true'}
|
||||
label='允许AccountFilter参数'
|
||||
name='MjAccountFilterEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.MjForwardUrlEnabled === 'true'}
|
||||
label='开启之后将上游地址替换为服务器地址'
|
||||
name='MjForwardUrlEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.MjModeClearEnabled === 'true'}
|
||||
label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
|
||||
name='MjModeClearEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>屏蔽词过滤设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
屏蔽词过滤设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.CheckSensitiveEnabled === 'true'}
|
||||
@@ -374,7 +406,9 @@ const OperationSetting = () => {
|
||||
保存屏蔽词设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>日志设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
@@ -402,7 +436,9 @@ const OperationSetting = () => {
|
||||
清理历史日志
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>数据看板</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
数据看板
|
||||
</Header>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DataExportEnabled === 'true'}
|
||||
label='启用数据看板(实验性)'
|
||||
@@ -432,7 +468,9 @@ const OperationSetting = () => {
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>监控设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
监控设置
|
||||
</Header>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='最长响应时间'
|
||||
@@ -477,7 +515,9 @@ const OperationSetting = () => {
|
||||
保存监控设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>额度设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
额度设置
|
||||
</Header>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input
|
||||
label='新用户初始额度'
|
||||
@@ -528,7 +568,9 @@ const OperationSetting = () => {
|
||||
保存额度设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>倍率设置</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
倍率设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
} from '../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { UserContext } from '../context/User';
|
||||
import { onGitHubOAuthClicked } from './utils';
|
||||
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
|
||||
import {
|
||||
Avatar,
|
||||
Banner,
|
||||
@@ -519,6 +519,39 @@ const PersonalSetting = () => {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>LINUX DO</Typography.Text>
|
||||
<div
|
||||
style={{ display: 'flex', justifyContent: 'space-between' }}
|
||||
>
|
||||
<div>
|
||||
<Input
|
||||
value={
|
||||
userState.user && userState.user.linuxdo_id !== ''
|
||||
? userState.user.linuxdo_id +
|
||||
'(' +
|
||||
userState.user.linuxdo_level +
|
||||
'级)'
|
||||
: '未绑定'
|
||||
}
|
||||
readonly={true}
|
||||
></Input>
|
||||
</div>
|
||||
<div>
|
||||
<Button
|
||||
onClick={() => {
|
||||
onLinuxDoOAuthClicked(status.linuxdo_client_id);
|
||||
}}
|
||||
disabled={
|
||||
(userState.user && userState.user.linuxdo_id !== '') ||
|
||||
!status.linuxdo_oauth
|
||||
}
|
||||
>
|
||||
{status.linuxdo_oauth ? '绑定' : '未启用'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>Telegram</Typography.Text>
|
||||
|
||||
@@ -77,6 +77,8 @@ const RegisterForm = () => {
|
||||
);
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
localStorage.removeItem('aff');
|
||||
|
||||
navigate('/login');
|
||||
showSuccess('注册成功!');
|
||||
} else {
|
||||
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
} from 'semantic-ui-react';
|
||||
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
|
||||
|
||||
import { useTheme } from '../context/Theme';
|
||||
|
||||
const SystemSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
PasswordLoginEnabled: '',
|
||||
@@ -18,6 +20,10 @@ const SystemSetting = () => {
|
||||
GitHubOAuthEnabled: '',
|
||||
GitHubClientId: '',
|
||||
GitHubClientSecret: '',
|
||||
LinuxDoOAuthEnabled: '',
|
||||
LinuxDoClientId: '',
|
||||
LinuxDoClientSecret: '',
|
||||
LinuxDoMinLevel: 0,
|
||||
Notice: '',
|
||||
SMTPServer: '',
|
||||
SMTPPort: '',
|
||||
@@ -25,13 +31,13 @@ const SystemSetting = () => {
|
||||
SMTPFrom: '',
|
||||
SMTPToken: '',
|
||||
ServerAddress: '',
|
||||
EpayId: '',
|
||||
EpayKey: '',
|
||||
Price: 7.3,
|
||||
MinTopUp: 1,
|
||||
StripeApiSecret: '',
|
||||
StripeWebhookSecret: '',
|
||||
StripePriceId: '',
|
||||
PaymentEnabled: false,
|
||||
StripeUnitPrice: 8.0,
|
||||
MinTopUp: 5,
|
||||
TopupGroupRatio: '',
|
||||
PayAddress: '',
|
||||
CustomCallbackAddress: '',
|
||||
Footer: '',
|
||||
WeChatAuthEnabled: '',
|
||||
WeChatServerAddress: '',
|
||||
@@ -41,7 +47,9 @@ const SystemSetting = () => {
|
||||
TurnstileSiteKey: '',
|
||||
TurnstileSecretKey: '',
|
||||
RegisterEnabled: '',
|
||||
UserSelfDeletionEnabled: false,
|
||||
EmailDomainRestrictionEnabled: '',
|
||||
EmailAliasRestrictionEnabled: '',
|
||||
SMTPSSLEnabled: '',
|
||||
EmailDomainWhitelist: [],
|
||||
// telegram login
|
||||
@@ -56,6 +64,9 @@ const SystemSetting = () => {
|
||||
const [showPasswordWarningModal, setShowPasswordWarningModal] =
|
||||
useState(false);
|
||||
|
||||
const theme = useTheme();
|
||||
const isDark = theme === 'dark';
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
@@ -95,12 +106,16 @@ const SystemSetting = () => {
|
||||
case 'PasswordRegisterEnabled':
|
||||
case 'EmailVerificationEnabled':
|
||||
case 'GitHubOAuthEnabled':
|
||||
case 'LinuxDoOAuthEnabled':
|
||||
case 'WeChatAuthEnabled':
|
||||
case 'TelegramOAuthEnabled':
|
||||
case 'TurnstileCheckEnabled':
|
||||
case 'EmailDomainRestrictionEnabled':
|
||||
case 'EmailAliasRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'RegisterEnabled':
|
||||
case 'UserSelfDeletionEnabled':
|
||||
case 'PaymentEnabled':
|
||||
value = inputs[key] === 'true' ? 'false' : 'true';
|
||||
break;
|
||||
default:
|
||||
@@ -115,9 +130,6 @@ const SystemSetting = () => {
|
||||
if (key === 'EmailDomainWhitelist') {
|
||||
value = value.split(',');
|
||||
}
|
||||
if (key === 'Price') {
|
||||
value = parseFloat(value);
|
||||
}
|
||||
setInputs((inputs) => ({
|
||||
...inputs,
|
||||
[key]: value,
|
||||
@@ -138,12 +150,16 @@ const SystemSetting = () => {
|
||||
name === 'Notice' ||
|
||||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
|
||||
name === 'ServerAddress' ||
|
||||
name === 'EpayId' ||
|
||||
name === 'EpayKey' ||
|
||||
name === 'Price' ||
|
||||
name === 'PayAddress' ||
|
||||
name === 'StripeApiSecret' ||
|
||||
name === 'StripeWebhookSecret' ||
|
||||
name === 'StripePriceId' ||
|
||||
name === 'StripeUnitPrice' ||
|
||||
name === 'MinTopUp' ||
|
||||
name === 'GitHubClientId' ||
|
||||
name === 'GitHubClientSecret' ||
|
||||
name === 'LinuxDoClientId' ||
|
||||
name === 'LinuxDoClientSecret' ||
|
||||
name === 'LinuxDoMinLevel' ||
|
||||
name === 'WeChatServerAddress' ||
|
||||
name === 'WeChatServerToken' ||
|
||||
name === 'WeChatAccountQRCodeImageURL' ||
|
||||
@@ -165,7 +181,7 @@ const SystemSetting = () => {
|
||||
await updateOption('ServerAddress', ServerAddress);
|
||||
};
|
||||
|
||||
const submitPayAddress = async () => {
|
||||
const submitPaymentConfig = async () => {
|
||||
if (inputs.ServerAddress === '') {
|
||||
showError('请先填写服务器地址');
|
||||
return;
|
||||
@@ -177,15 +193,31 @@ const SystemSetting = () => {
|
||||
}
|
||||
await updateOption('TopupGroupRatio', inputs.TopupGroupRatio);
|
||||
}
|
||||
let PayAddress = removeTrailingSlash(inputs.PayAddress);
|
||||
await updateOption('PayAddress', PayAddress);
|
||||
if (inputs.EpayId !== '') {
|
||||
await updateOption('EpayId', inputs.EpayId);
|
||||
let stripeApiSecret = removeTrailingSlash(inputs.StripeApiSecret);
|
||||
if (stripeApiSecret && !stripeApiSecret.startsWith('sk_')) {
|
||||
showError('输入了无效的Stripe API密钥');
|
||||
return;
|
||||
}
|
||||
if (inputs.EpayKey !== '') {
|
||||
await updateOption('EpayKey', inputs.EpayKey);
|
||||
stripeApiSecret && (await updateOption('StripeApiSecret', stripeApiSecret));
|
||||
|
||||
let stripeWebhookSecret = removeTrailingSlash(inputs.StripeWebhookSecret);
|
||||
if (stripeWebhookSecret && !stripeWebhookSecret.startsWith('whsec_')) {
|
||||
showError('输入了无效的Stripe Webhook签名密钥');
|
||||
return;
|
||||
}
|
||||
await updateOption('Price', '' + inputs.Price);
|
||||
stripeWebhookSecret &&
|
||||
(await updateOption('StripeWebhookSecret', stripeWebhookSecret));
|
||||
|
||||
let stripePriceId = removeTrailingSlash(inputs.StripePriceId);
|
||||
if (stripePriceId && !stripePriceId.startsWith('price_')) {
|
||||
showError('输入了无效的Stripe 物品价格ID');
|
||||
return;
|
||||
}
|
||||
await updateOption('StripePriceId', stripePriceId);
|
||||
|
||||
await updateOption('PaymentEnable', inputs.PaymentEnabled);
|
||||
await updateOption('StripeUnitPrice', inputs.StripeUnitPrice);
|
||||
await updateOption('MinTopUp', inputs.MinTopUp);
|
||||
};
|
||||
|
||||
const submitSMTP = async () => {
|
||||
@@ -261,6 +293,21 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitLinuxDoOAuth = async () => {
|
||||
if (originInputs['LinuxDoClientId'] !== inputs.LinuxDoClientId) {
|
||||
await updateOption('LinuxDoClientId', inputs.LinuxDoClientId);
|
||||
}
|
||||
if (
|
||||
originInputs['LinuxDoClientSecret'] !== inputs.LinuxDoClientSecret &&
|
||||
inputs.LinuxDoClientSecret !== ''
|
||||
) {
|
||||
await updateOption('LinuxDoClientSecret', inputs.LinuxDoClientSecret);
|
||||
}
|
||||
if (originInputs['LinuxDoMinLevel'] !== inputs.LinuxDoMinLevel) {
|
||||
await updateOption('LinuxDoMinLevel', inputs.LinuxDoMinLevel);
|
||||
}
|
||||
};
|
||||
|
||||
const submitTelegramSettings = async () => {
|
||||
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
|
||||
await updateOption('TelegramBotToken', inputs.TelegramBotToken);
|
||||
@@ -304,8 +351,10 @@ const SystemSetting = () => {
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
<Header as='h3'>通用设置</Header>
|
||||
<Form loading={loading} inverted={isDark}>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
通用设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='服务器地址'
|
||||
@@ -319,53 +368,73 @@ const SystemSetting = () => {
|
||||
更新服务器地址
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
支付设置(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)
|
||||
<Header as='h3' inverted={isDark}>
|
||||
支付设置(当前仅支持Stripe Checkout)
|
||||
<Header.Subheader>
|
||||
密钥、Webhook 等设置请
|
||||
<a
|
||||
href='https://dashboard.stripe.com/developers'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
点击此处
|
||||
</a>
|
||||
进行设置,最好先在
|
||||
<a
|
||||
href='https://dashboard.stripe.com/test/developers'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
测试环境
|
||||
</a>
|
||||
进行测试
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Message>
|
||||
Webhook 填:
|
||||
<code>{`${inputs.ServerAddress}/api/stripe/webhook`}</code>
|
||||
,需要包含事件:<code>checkout.session.completed</code> 和{' '}
|
||||
<code>checkout.session.expired</code>
|
||||
</Message>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='支付地址,不填写则不启用在线支付'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
value={inputs.PayAddress}
|
||||
name='PayAddress'
|
||||
label='API密钥'
|
||||
placeholder='sk_xxx的Stripe密钥,敏感信息不显示'
|
||||
value={inputs.StripeApiSecret}
|
||||
name='StripeApiSecret'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='易支付商户ID'
|
||||
placeholder='例如:0001'
|
||||
value={inputs.EpayId}
|
||||
name='EpayId'
|
||||
label='Webhook签名密钥'
|
||||
placeholder='whsec_xxx的Webhook签名密钥,敏感信息不显示'
|
||||
value={inputs.StripeWebhookSecret}
|
||||
name='StripeWebhookSecret'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='易支付商户密钥'
|
||||
placeholder='例如:dejhfueqhujasjmndbjkqaw'
|
||||
value={inputs.EpayKey}
|
||||
name='EpayKey'
|
||||
label='商品价格ID'
|
||||
placeholder='price_xxx的商品价格ID,新建产品后可获得'
|
||||
value={inputs.StripePriceId}
|
||||
name='StripePriceId'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.Input
|
||||
label='回调地址,不填写则使用上方服务器地址作为回调地址'
|
||||
placeholder='例如:https://yourdomain.com'
|
||||
value={inputs.CustomCallbackAddress}
|
||||
name='CustomCallbackAddress'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='充值价格(x元/美金)'
|
||||
placeholder='例如:7,就是7元/美金'
|
||||
value={inputs.Price}
|
||||
name='Price'
|
||||
label='商品单价(元)'
|
||||
placeholder='商品的人民币价格'
|
||||
value={inputs.StripeUnitPrice}
|
||||
name='StripeUnitPrice'
|
||||
type={'number'}
|
||||
min={0}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='最低充值数量'
|
||||
placeholder='例如:2,就是最低充值2$'
|
||||
placeholder='例如:2,就是最低充值2件商品'
|
||||
value={inputs.MinTopUp}
|
||||
name='MinTopUp'
|
||||
type={'number'}
|
||||
min={1}
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
@@ -381,9 +450,21 @@ const SystemSetting = () => {
|
||||
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitPayAddress}>更新支付设置</Form.Button>
|
||||
<Form.Group inline>
|
||||
<Form.Button onClick={submitPaymentConfig}>
|
||||
更新支付设置
|
||||
</Form.Button>
|
||||
<Form.Checkbox
|
||||
checked={inputs.PaymentEnabled === 'true'}
|
||||
label='开启在线支付'
|
||||
name='PaymentEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>配置登录注册</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置登录注册
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.PasswordLoginEnabled === 'true'}
|
||||
@@ -438,6 +519,12 @@ const SystemSetting = () => {
|
||||
name='GitHubOAuthEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LinuxDoOAuthEnabled === 'true'}
|
||||
label='允许通过 LINUX DO 账户登录 & 注册'
|
||||
name='LinuxDoOAuthEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.WeChatAuthEnabled === 'true'}
|
||||
label='允许通过微信登录 & 注册'
|
||||
@@ -464,9 +551,15 @@ const SystemSetting = () => {
|
||||
name='TurnstileCheckEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.UserSelfDeletionEnabled === 'true'}
|
||||
label='允许用户自行删除账户'
|
||||
name='UserSelfDeletionEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置邮箱域名白名单
|
||||
<Header.Subheader>
|
||||
用以防止恶意用户利用临时邮箱批量注册
|
||||
@@ -480,6 +573,14 @@ const SystemSetting = () => {
|
||||
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Checkbox
|
||||
label='启用邮箱别名限制(例如:ab.cd@gmail.com)'
|
||||
name='EmailAliasRestrictionEnabled'
|
||||
onChange={handleInputChange}
|
||||
checked={inputs.EmailAliasRestrictionEnabled === 'true'}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={2}>
|
||||
<Form.Dropdown
|
||||
label='允许的邮箱域名'
|
||||
@@ -523,7 +624,7 @@ const SystemSetting = () => {
|
||||
保存邮箱域名白名单设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 SMTP
|
||||
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
|
||||
</Header>
|
||||
@@ -582,7 +683,7 @@ const SystemSetting = () => {
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 GitHub OAuth App
|
||||
<Header.Subheader>
|
||||
用以支持通过 GitHub 进行登录注册,
|
||||
@@ -625,6 +726,58 @@ const SystemSetting = () => {
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
配置 LINUX DO Oauth
|
||||
<Header.Subheader>
|
||||
用以支持通过 LINUX DO 进行登录注册,
|
||||
<a
|
||||
href='https://connect.linux.do'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
点击此处
|
||||
</a>
|
||||
管理你的 LINUX DO OAuth
|
||||
</Header.Subheader>
|
||||
</Header>
|
||||
<Message>
|
||||
Homepage URL 填 <code>{inputs.ServerAddress}</code>
|
||||
,Authorization callback URL 填{' '}
|
||||
<code>{`${inputs.ServerAddress}/oauth/linuxdo`}</code>
|
||||
</Message>
|
||||
<Form.Group widths={3}>
|
||||
<Form.Input
|
||||
label='LINUX DO Client ID'
|
||||
name='LinuxDoClientId'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.LinuxDoClientId}
|
||||
placeholder='输入你注册的 LINUX DO OAuth 的 ID'
|
||||
/>
|
||||
<Form.Input
|
||||
label='LINUX DO Client Secret'
|
||||
name='LinuxDoClientSecret'
|
||||
onChange={handleInputChange}
|
||||
type='password'
|
||||
autoComplete='new-password'
|
||||
value={inputs.LinuxDoClientSecret}
|
||||
placeholder='敏感信息不会发送到前端显示'
|
||||
/>
|
||||
<Form.Input
|
||||
label='限制最低信任等级'
|
||||
name='LinuxDoMinLevel'
|
||||
onChange={handleInputChange}
|
||||
type='number'
|
||||
min={0}
|
||||
max={4}
|
||||
value={inputs.LinuxDoMinLevel}
|
||||
placeholder='输入允许使用的最低 LINUX DO 信任等级'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={submitLinuxDoOAuth}>
|
||||
保存 LINUX DO OAuth 设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 WeChat Server
|
||||
<Header.Subheader>
|
||||
用以支持通过微信进行登录注册,
|
||||
@@ -669,7 +822,9 @@ const SystemSetting = () => {
|
||||
保存 WeChat Server 设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>配置 Telegram 登录</Header>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 Telegram 登录
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Input
|
||||
label='Telegram Bot Token'
|
||||
@@ -690,7 +845,7 @@ const SystemSetting = () => {
|
||||
保存 Telegram 登录设置
|
||||
</Form.Button>
|
||||
<Divider />
|
||||
<Header as='h3'>
|
||||
<Header as='h3' inverted={isDark}>
|
||||
配置 Turnstile
|
||||
<Header.Subheader>
|
||||
用以支持用户校验,
|
||||
|
||||
@@ -210,21 +210,39 @@ const UsersTable = () => {
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
<Popconfirm
|
||||
title='确定是否要删除此用户?'
|
||||
content='硬删除,此修改将不可逆'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'delete', record).then(() => {
|
||||
removeRecord(record.id);
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
{record.DeletedAt !== null ? (
|
||||
<Popconfirm
|
||||
title='确定是否要删除此用户?'
|
||||
content='硬删除,此修改将不可逆'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
hardDeleteUser(record.id).then(() => {
|
||||
removeRecord(record.id);
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
|
||||
永久删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
) : (
|
||||
<Popconfirm
|
||||
title='确定是否要删除此用户?'
|
||||
content='软删除,数据依然留底'
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'delete', record).then(() => {
|
||||
record.DeletedAt = new Date();
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
},
|
||||
@@ -235,6 +253,8 @@ const UsersTable = () => {
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [searchGroup, setSearchGroup] = useState('');
|
||||
const [groupOptions, setGroupOptions] = useState([]);
|
||||
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
|
||||
const [showAddUser, setShowAddUser] = useState(false);
|
||||
const [showEditUser, setShowEditUser] = useState(false);
|
||||
@@ -298,6 +318,7 @@ const UsersTable = () => {
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
});
|
||||
fetchGroups().then();
|
||||
}, []);
|
||||
|
||||
const manageUser = async (username, action, record) => {
|
||||
@@ -318,6 +339,18 @@ const UsersTable = () => {
|
||||
setUsers(newUsers);
|
||||
} else {
|
||||
showError(message);
|
||||
throw new Error(message);
|
||||
}
|
||||
};
|
||||
|
||||
const hardDeleteUser = async (userId) => {
|
||||
const res = await API.delete('/api/user/' + userId);
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess('操作成功完成!');
|
||||
} else {
|
||||
showError(message);
|
||||
throw new Error(message);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -340,15 +373,17 @@ const UsersTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const searchUsers = async () => {
|
||||
if (searchKeyword === '') {
|
||||
const searchUsers = async (searchKeyword, searchGroup) => {
|
||||
if (searchKeyword === '' && searchGroup === '') {
|
||||
// if keyword is blank, load files instead.
|
||||
await loadUsers(0);
|
||||
setActivePage(1);
|
||||
return;
|
||||
}
|
||||
setSearching(true);
|
||||
const res = await API.get(`/api/user/search?keyword=${searchKeyword}`);
|
||||
const res = await API.get(
|
||||
`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setUsers(data);
|
||||
@@ -409,6 +444,25 @@ const UsersTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const fetchGroups = async () => {
|
||||
try {
|
||||
let res = await API.get(`/api/group/`);
|
||||
// add 'all' option
|
||||
// res.data.data.unshift('all');
|
||||
if (res === undefined) {
|
||||
return;
|
||||
}
|
||||
setGroupOptions(
|
||||
res.data.data.map((group) => ({
|
||||
label: group,
|
||||
value: group,
|
||||
})),
|
||||
);
|
||||
} catch (error) {
|
||||
showError(error.message);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<AddUser
|
||||
@@ -422,17 +476,44 @@ const UsersTable = () => {
|
||||
handleClose={closeEditUser}
|
||||
editingUser={editingUser}
|
||||
></EditUser>
|
||||
<Form onSubmit={searchUsers}>
|
||||
<Form.Input
|
||||
label='搜索关键字'
|
||||
icon='search'
|
||||
field='keyword'
|
||||
iconPosition='left'
|
||||
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
|
||||
value={searchKeyword}
|
||||
loading={searching}
|
||||
onChange={(value) => handleKeywordChange(value)}
|
||||
/>
|
||||
<Form
|
||||
onSubmit={() => {
|
||||
searchUsers(searchKeyword, searchGroup);
|
||||
}}
|
||||
labelPosition='left'
|
||||
>
|
||||
<div style={{ display: 'flex' }}>
|
||||
<Space>
|
||||
<Form.Input
|
||||
label='搜索关键字'
|
||||
icon='search'
|
||||
field='keyword'
|
||||
iconPosition='left'
|
||||
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
|
||||
value={searchKeyword}
|
||||
loading={searching}
|
||||
onChange={(value) => handleKeywordChange(value)}
|
||||
/>
|
||||
<Form.Select
|
||||
field='group'
|
||||
label='分组'
|
||||
optionList={groupOptions}
|
||||
onChange={(value) => {
|
||||
setSearchGroup(value);
|
||||
searchUsers(searchKeyword, value);
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
label='查询'
|
||||
type='primary'
|
||||
htmlType='submit'
|
||||
className='btn-margin-right'
|
||||
style={{ marginRight: 8 }}
|
||||
>
|
||||
查询
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
</Form>
|
||||
|
||||
<Table
|
||||
|
||||
@@ -14,7 +14,11 @@ export async function getOAuthState() {
|
||||
export async function onGitHubOAuthClicked(github_client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
window.open(
|
||||
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
|
||||
);
|
||||
location.href = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`;
|
||||
}
|
||||
|
||||
export async function onLinuxDoOAuthClicked(linuxdo_client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
location.href = `https://connect.linux.do/oauth2/authorize?client_id=${linuxdo_client_id}&response_type=code&state=${state}&scope=user:profile`;
|
||||
}
|
||||
|
||||
@@ -22,6 +22,13 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'indigo',
|
||||
label: 'Anthropic Claude',
|
||||
},
|
||||
{
|
||||
key: 33,
|
||||
text: 'AWS Claude',
|
||||
value: 33,
|
||||
color: 'indigo',
|
||||
label: 'AWS Claude',
|
||||
},
|
||||
{
|
||||
key: 3,
|
||||
text: 'Azure OpenAI',
|
||||
@@ -43,6 +50,13 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'orange',
|
||||
label: 'Google Gemini',
|
||||
},
|
||||
{
|
||||
key: 34,
|
||||
text: 'Cohere',
|
||||
value: 34,
|
||||
color: 'purple',
|
||||
label: 'Cohere',
|
||||
},
|
||||
{
|
||||
key: 15,
|
||||
text: '百度文心千帆',
|
||||
|
||||
36
web/src/context/Theme/index.js
Normal file
36
web/src/context/Theme/index.js
Normal file
@@ -0,0 +1,36 @@
|
||||
import { createContext, useCallback, useContext, useState } from 'react';
|
||||
|
||||
const ThemeContext = createContext(null);
|
||||
export const useTheme = () => useContext(ThemeContext);
|
||||
|
||||
const SetThemeContext = createContext(null);
|
||||
export const useSetTheme = () => useContext(SetThemeContext);
|
||||
|
||||
export const ThemeProvider = ({ children }) => {
|
||||
const [theme, _setTheme] = useState(() => {
|
||||
try {
|
||||
return localStorage.getItem('theme-mode') || null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
const setTheme = useCallback((input) => {
|
||||
_setTheme(input ? 'dark' : 'light');
|
||||
|
||||
const body = document.body;
|
||||
if (!input) {
|
||||
body.removeAttribute('theme-mode');
|
||||
localStorage.setItem('theme-mode', 'light');
|
||||
} else {
|
||||
body.setAttribute('theme-mode', 'dark');
|
||||
localStorage.setItem('theme-mode', 'dark');
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<SetThemeContext.Provider value={setTheme}>
|
||||
<ThemeContext.Provider value={theme}>{children}</ThemeContext.Provider>
|
||||
</SetThemeContext.Provider>
|
||||
);
|
||||
};
|
||||
@@ -15,14 +15,18 @@ export function renderText(text, limit) {
|
||||
*/
|
||||
export function renderGroup(group) {
|
||||
if (group === '') {
|
||||
return <Tag size='large' key='default'>default</Tag>;
|
||||
return (
|
||||
<Tag size='large' key='default'>
|
||||
unknown
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
|
||||
const tagColors = {
|
||||
'vip': 'yellow',
|
||||
'pro': 'yellow',
|
||||
'svip': 'red',
|
||||
'premium': 'red'
|
||||
vip: 'yellow',
|
||||
pro: 'yellow',
|
||||
svip: 'red',
|
||||
premium: 'red',
|
||||
};
|
||||
|
||||
const groups = group.split(',').sort();
|
||||
@@ -97,12 +101,29 @@ export function getQuotaPerUnit() {
|
||||
return quotaPerUnit;
|
||||
}
|
||||
|
||||
export function renderUnitWithQuota(quota) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
quota = parseFloat(quota);
|
||||
return quotaPerUnit * quota;
|
||||
}
|
||||
|
||||
export function getQuotaWithUnit(quota, digits = 6) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
return (quota / quotaPerUnit).toFixed(digits);
|
||||
}
|
||||
|
||||
export function renderQuotaWithAmount(amount) {
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return '$' + amount;
|
||||
} else {
|
||||
return renderUnitWithQuota(amount);
|
||||
}
|
||||
}
|
||||
|
||||
export function renderQuota(quota, digits = 2) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
@@ -145,13 +166,12 @@ export const modelColorMap = {
|
||||
'dall-e': 'rgb(147,112,219)', // 深紫色
|
||||
'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
|
||||
'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调
|
||||
midjourney: 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调
|
||||
'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色
|
||||
'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
|
||||
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
|
||||
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
|
||||
'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色
|
||||
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃色
|
||||
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
|
||||
'gpt-4': 'rgb(135,206,235)', // 天蓝色
|
||||
'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
|
||||
@@ -159,6 +179,8 @@ export const modelColorMap = {
|
||||
'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝
|
||||
'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝
|
||||
'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝
|
||||
'gpt-4-turbo': 'rgb(2,190,255)', // 深天蓝
|
||||
'gpt-4-turbo-2024-04-09': 'rgb(2,200,255)', // 深天蓝
|
||||
'gpt-4-32k': 'rgb(104,111,238)', // 中紫色
|
||||
'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色
|
||||
'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色
|
||||
@@ -180,6 +202,10 @@ export const modelColorMap = {
|
||||
'tts-1-hd': 'rgb(255,215,0)', // 金色
|
||||
'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别)
|
||||
'whisper-1': 'rgb(245,245,220)', // 米色
|
||||
'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色
|
||||
'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色
|
||||
'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色
|
||||
'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别)
|
||||
};
|
||||
|
||||
export function stringToColor(str) {
|
||||
|
||||
@@ -79,6 +79,7 @@ export function showError(error) {
|
||||
switch (error.response.status) {
|
||||
case 401:
|
||||
// toast.error('错误:未登录或登录已过期,请重新登录!', showErrorOptions);
|
||||
localStorage.removeItem('user');
|
||||
window.location.href = '/login?expired=true';
|
||||
break;
|
||||
case 429:
|
||||
@@ -126,6 +127,10 @@ export function openPage(url) {
|
||||
}
|
||||
|
||||
export function removeTrailingSlash(url) {
|
||||
if (!url) {
|
||||
return '';
|
||||
}
|
||||
|
||||
if (url.endsWith('/')) {
|
||||
return url.slice(0, -1);
|
||||
} else {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user