Compare commits

..

25 Commits

Author SHA1 Message Date
wozulong
7d18a8e2a9 merge upstream
Signed-off-by: wozulong <>
2024-04-05 22:10:07 +08:00
wozulong
80af3718d0 lint fix
Signed-off-by: wozulong <>
2024-04-01 10:16:35 +08:00
wozulong
77ea6bec46 Merge remote-tracking branch 'upstream/main' 2024-04-01 10:15:48 +08:00
wozulong
c0ab8ae953 update github action
Signed-off-by: wozulong <>
2024-03-27 18:04:27 +08:00
wozulong
923c2dee32 merge upstream
Signed-off-by: wozulong <>
2024-03-27 18:01:43 +08:00
wozulong
ea17a46d8e fixed pagination issue in the log list
Signed-off-by: wozulong <>
2024-03-25 16:36:57 +08:00
wozulong
bfe9e5d25a Merge remote-tracking branch 'upstream/main' 2024-03-25 15:49:12 +08:00
wozulong
831ff47254 Merge remote-tracking branch 'upstream/main' 2024-03-24 00:18:02 +08:00
wozulong
11eaba6b5d fix gpt-3.5-turbo points
Signed-off-by: wozulong <>
2024-03-24 00:17:49 +08:00
wozulong
c2e4ec25c8 merge upstream
Signed-off-by: wozulong <>
2024-03-23 22:51:18 +08:00
wozulong
8537f10412 save stripe custome id
Signed-off-by: wozulong <>
2024-03-23 21:59:15 +08:00
wozulong
71d60eeef7 merge upstream
Signed-off-by: wozulong <>
2024-03-22 18:09:13 +08:00
wozulong
247ae0988f replace epay with stripe
Signed-off-by: wozulong <>
2024-03-22 18:00:20 +08:00
wozulong
0907fa6994 fix next uid
Signed-off-by: wozulong <>
2024-03-21 15:02:47 +08:00
wozulong
9855343aa8 Merge remote-tracking branch 'upstream/main' 2024-03-21 15:01:51 +08:00
wozulong
bd50fde268 Merge remote-tracking branch 'upstream/main' 2024-03-21 13:10:21 +08:00
wozulong
4267de5642 fix yarn timeout
Signed-off-by: wozulong <>
2024-03-20 19:14:17 +08:00
wozulong
8b55116563 update build files
Signed-off-by: wozulong <>
2024-03-20 18:31:22 +08:00
wozulong
f35e63e3f3 limit 'LINUX DO' trust level now available
Signed-off-by: wozulong <>
2024-03-20 16:54:38 +08:00
wozulong
17c409de23 update gh action
Signed-off-by: wozulong <>
2024-03-20 14:11:21 +08:00
wozulong
e4753e7411 synced with upstream
Signed-off-by: wozulong <>
2024-03-20 13:52:10 +08:00
paderlol
9adefa80b9 Optimizing Docker image builds (#1)
Co-authored-by: pader.zhang <pader.zhang@starlight-sms.com>
2024-03-14 23:10:05 -07:00
wozulong
4ce2381182 update issue template
Signed-off-by: wozulong <>
2024-03-14 20:02:22 +08:00
我秦始皇
62afc21ea5 Merge branch 'Calcium-Ion:main' into main 2024-03-14 03:56:31 -07:00
wozulong
7ddb7c586d 1. add LINUX DO oauth
2. fix oauth reg aff issue

Signed-off-by: wozulong <>
2024-03-14 18:53:54 +08:00
108 changed files with 4151 additions and 2589 deletions

View File

@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: 项目群聊
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
about: QQ 群629454374
- name: 交流社区
url: https://linux.do
about: 项目交流社区

View File

@@ -1,9 +1,6 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- '*'
workflow_dispatch:
inputs:
name:
@@ -42,7 +39,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
calciumion/new-api
pengzhile/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -48,7 +48,7 @@ jobs:
uses: docker/metadata-action@v4
with:
images: |
calciumion/new-api
pengzhile/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images

View File

@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) yarn build
start-backend:
@echo "Starting backend dev server..."

View File

@@ -60,8 +60,8 @@
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,建议开启缓存功能。
如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
渠道重试功能已经实现,可以在渠道管理中设置重试次数,需要开启缓存功能,否则只会使用同优先级重试
如果开启了缓存功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`

View File

@@ -9,6 +9,15 @@ import (
"github.com/google/uuid"
)
// Pay Settings
var StripeApiSecret = ""
var StripeWebhookSecret = ""
var StripePriceId = ""
var PaymentEnabled = false
var StripeUnitPrice = 8.0
var MinTopUp = 5
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "New API"
@@ -41,13 +50,14 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var LinuxDoOAuthEnabled = false
var WeChatAuthEnabled = false
var TelegramOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var UserSelfDeletionEnabled = false
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
@@ -75,6 +85,10 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDoClientId = ""
var LinuxDoClientSecret = ""
var LinuxDoMinLevel = 0
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
@@ -176,6 +190,12 @@ const (
ChannelStatusAutoDisabled = 3
)
const (
TopUpStatusPending = "pending"
TopUpStatusSuccess = "success"
TopUpStatusExpired = "expired"
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
@@ -206,10 +226,6 @@ const (
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
@@ -245,7 +261,4 @@ var ChannelBaseURLs = []string{
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
}

View File

@@ -16,22 +16,7 @@ func SafeGoroutine(f func()) {
}()
}
func SafeSendBool(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}
func SafeSendString(ch chan string, value string) (closed bool) {
func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {

84
common/hash.go Normal file
View File

@@ -0,0 +1,84 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"math/rand"
"time"
)
func Sha256Raw(data string) []byte {
h := sha256.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write([]byte(data))
return h.Sum(nil)
}
func Sha1(data string) string {
return hex.EncodeToString(Sha1Raw([]byte(data)))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}
func RandomBytes(length int) []byte {
rand.Seed(time.Now().UnixNano())
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return b
}
func RandomString(length int) string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomHex(length int) string {
const chars = "abcdef0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomNumber(length int) string {
const chars = "0123456789"
result := make([]byte, length)
randomBytes := RandomBytes(length)
for i := 0; i < length; i++ {
result[i] = chars[randomBytes[i]%byte(len(chars))]
}
return string(result)
}
func RandomUUID() string {
all := RandomHex(32)
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
}

View File

@@ -2,7 +2,6 @@ package common
import (
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
}
}
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
if err != nil {
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
return
func LogQuotaF(quota float64) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", quota/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", int64(quota))
}
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
}

View File

@@ -12,108 +12,94 @@ import (
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var DefaultModelRatio = map[string]float64{
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
//"gpt-4-0314": 15, //deprecated
"gpt-4-0613": 15,
"gpt-4-32k": 30,
//"gpt-4-32k-0314": 30, //deprecated
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
"gpt-4-32k": 30,
"gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
//"gpt-3.5-turbo-0301": 0.75, //deprecated
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // 0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro-latest": 1,
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // 0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // 0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // 0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25,
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
"text-davinci-002": 10,
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-1106": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"tts-1-hd-1106": 15, // 1k characters -> $0.03
"davinci": 10,
"curie": 10,
"babbage": 10,
"ada": 10,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-embedding-ada-002": 0.05,
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"dall-e-2": 8,
"dall-e-3": 16,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.0-pro-vision-001": 1,
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// https://platform.lingyiwanwu.com/docs#-计费单元
// 已经按照 7.2 来换算美元价格
"yi-34b-chat-0205": 0.018,
"yi-34b-chat-200k": 0.0864,
"yi-vl-plus": 0.0432,
"command": 0.5,
"command-nightly": 0.5,
"command-light": 0.5,
"command-light-nightly": 0.5,
"command-r": 0.25,
"command-r-plus ": 1.5,
"deepseek-chat": 0.07,
"deepseek-coder": 0.07,
"yi-34b-chat-0205": 0.018,
"yi-34b-chat-200k": 0.0864,
"yi-vl-plus": 0.0432,
}
var DefaultModelPrice = map[string]float64{
@@ -138,12 +124,6 @@ var DefaultModelPrice = map[string]float64{
var modelPrice map[string]float64 = nil
var modelRatio map[string]float64 = nil
var CompletionRatio map[string]float64 = nil
var DefaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4-all": 2,
}
func ModelPrice2JSONString() string {
if modelPrice == nil {
modelPrice = DefaultModelPrice
@@ -208,45 +188,31 @@ func GetModelRatio(name string) float64 {
return ratio
}
func CompletionRatio2JSONString() string {
if CompletionRatio == nil {
CompletionRatio = DefaultCompletionRatio
}
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-3.5") {
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
if strings.HasSuffix(name, "0125") {
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
return 4.0 / 3.0
if name == "gpt-3.5-turbo" {
return 3
}
return 1.333333
}
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && !strings.HasPrefix(name, "gpt-4-gizmo") {
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.Contains(name, "claude-instant-1") {
if strings.HasPrefix(name, "claude-instant-1") {
return 3
} else if strings.Contains(name, "claude-2") {
} else if strings.HasPrefix(name, "claude-2") {
return 3
} else if strings.Contains(name, "claude-3") {
} else if strings.HasPrefix(name, "claude-3") {
return 5
}
if strings.HasPrefix(name, "mistral-") {
@@ -255,29 +221,9 @@ func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gemini-") {
return 3
}
if strings.HasPrefix(name, "command") {
switch name {
case "command-r":
return 3
case "command-r-plus":
return 5
default:
return 2
}
}
if strings.HasPrefix(name, "deepseek") {
return 2
}
switch name {
case "llama2-70b-4096":
return 0.8 / 0.64
case "llama3-8b-8192":
return 2
case "llama3-70b-8192":
return 0.79 / 0.59
}
if ratio, ok := CompletionRatio[name]; ok {
return ratio
return 0.8 / 0.7
}
return 1
}

View File

@@ -1,7 +1,6 @@
package common
import (
"encoding/json"
"fmt"
"github.com/google/uuid"
"html/template"
@@ -242,11 +241,3 @@ func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func MapToJsonStr(m map[string]interface{}) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}

View File

@@ -1,9 +1,6 @@
package constant
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
const (
MjErrorUnknown = 5

View File

@@ -1,8 +0,0 @@
package constant
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1

View File

@@ -86,7 +86,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if err != nil {
return err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
if resp.StatusCode != http.StatusOK {
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
@@ -208,7 +208,7 @@ func testAllChannels(notify bool) error {
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)

View File

@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
}
} else {
if common.RegisterEnabled {
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

239
controller/linuxdo.go Normal file
View File

@@ -0,0 +1,239 @@
package controller
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type LinuxDoOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type LinuxDoUser struct {
ID int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
form := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
}
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Basic "+auth)
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LinuxDoOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
}
defer res2.Body.Close()
var linuxdoUser LinuxDoUser
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
if err != nil {
return nil, err
}
if linuxdoUser.ID == 0 {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
return nil, errors.New("用户 LINUX DO 信任等级不足!")
}
return &linuxdoUser, nil
}
func LinuxDoOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
err := user.FillUserByLinuxDoId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
affCode := c.Query("aff")
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
if linuxdoUser.Name != "" {
user.DisplayName = linuxdoUser.Name
} else {
user.DisplayName = linuxdoUser.Username
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(user.InviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDoOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 LINUX DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
LinuxDoLevel: linuxdoUser.TrustLevel,
}
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 LINUX DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
user.LinuxDoLevel = linuxdoUser.TrustLevel
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -24,7 +24,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*pageSize, pageSize, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -35,6 +35,7 @@ func GetAllLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"total": total,
"data": logs,
})
return
@@ -58,7 +59,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -69,6 +70,7 @@ func GetUserLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"total": total,
"data": logs,
})
return

View File

@@ -10,11 +10,11 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/service"
"strconv"
"strings"
"time"
)
@@ -86,7 +86,7 @@ func UpdateMidjourneyTaskBulk() {
continue
}
// 设置超时时间
timeout := time.Second * 15
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
@@ -233,12 +233,6 @@ func GetAllMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
@@ -265,7 +259,7 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney

View File

@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
"linuxdo_client_id": common.LinuxDoClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
@@ -46,8 +48,8 @@ func GetStatus(c *gin.Context) {
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": constant.Price,
"min_topup": constant.MinTopUp,
"stripe_unit_price": common.StripeUnitPrice,
"min_topup": common.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
@@ -60,7 +62,7 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
"payment_enabled": common.PaymentEnabled,
"mj_notify_enabled": constant.MjNotifyEnabled,
},
})
@@ -120,17 +122,12 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled {
parts := strings.Split(email, "@")
localPart := parts[0]
domainPart := parts[1]
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Count(localPart, ".") > 1
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain {
@@ -138,7 +135,13 @@ func SendEmailVerification(c *gin.Context) {
break
}
}
if !allowed {
if allowed && !containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Your email address is allowed.",
})
return
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
@@ -146,17 +149,6 @@ func SendEmailVerification(c *gin.Context) {
return
}
}
if common.EmailAliasRestrictionEnabled {
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
if containsSpecialSymbols {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
})
return
}
}
if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -4,15 +4,13 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/relay/channel/lingyiwanwu"
relayconstant "one-api/relay/constant"
)
@@ -45,9 +43,8 @@ type OpenAIModels struct {
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func getPermission() []OpenAIModelPermission {
func init() {
var permission []OpenAIModelPermission
permission = append(permission, OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
@@ -63,12 +60,7 @@ func getPermission() []OpenAIModelPermission {
Group: nil,
IsBlocking: false,
})
return permission
}
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
permission := getPermission()
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
continue
@@ -93,7 +85,7 @@ func init() {
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
OwnedBy: "360",
Permission: permission,
Root: modelName,
Parent: nil,
@@ -136,17 +128,6 @@ func init() {
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType := relayconstant.ChannelType2APIType(i)
if apiType == -1 || apiType == relayconstant.APITypeAIProxyLibrary {
continue
}
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta, dto.GeneralOpenAIRequest{})
channelId2Models[i] = adaptor.GetModelList()
}
}
func ListModels(c *gin.Context) {
@@ -161,39 +142,21 @@ func ListModels(c *gin.Context) {
}
models := model.GetGroupModels(user.Group)
userOpenAiModels := make([]OpenAIModels, 0)
permission := getPermission()
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "openai",
Permission: permission,
Root: s,
Parent: nil,
})
}
}
c.JSON(200, gin.H{
"success": true,
"data": userOpenAiModels,
"object": "list",
"data": userOpenAiModels,
})
}
func ChannelListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": openAIModels,
})
}
func DashboardListModels(c *gin.Context) {
c.JSON(200, gin.H{
"success": true,
"data": channelId2Models,
"object": "list",
"data": openAIModels,
})
}

View File

@@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) {
var options []*model.Option
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
options = append(options, &model.Option{
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "LinuxDoOAuthEnabled":
if option.Value == "true" && common.LinuxDoClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 LINUX DO OAuth请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret",
})
return
}
case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{

View File

@@ -15,7 +15,6 @@ import (
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
)
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
@@ -43,7 +42,7 @@ func Relay(c *gin.Context) {
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
useChannel := []int{channelId}
retryLogStr := fmt.Sprintf("重试:%d", channelId)
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
} else {
@@ -56,7 +55,7 @@ func Relay(c *gin.Context) {
break
}
channelId = channel.Id
useChannel = append(useChannel, channelId)
retryLogStr += fmt.Sprintf("->%d", channel.Id)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
@@ -67,10 +66,7 @@ func Relay(c *gin.Context) {
go processChannelError(c, channelId, openaiErr)
}
}
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
}
common.LogInfo(c.Request.Context(), retryLogStr)
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
@@ -96,9 +92,6 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
if openaiErr.StatusCode == http.StatusTooManyRequests {
return true
}
if openaiErr.StatusCode == 307 {
return true
}
if openaiErr.StatusCode/100 == 5 {
// 超时不重试
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
@@ -109,10 +102,6 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
if openaiErr.StatusCode == http.StatusBadRequest {
return false
}
if openaiErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if openaiErr.LocalError {
return false
}
@@ -124,7 +113,7 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
@@ -160,7 +149,7 @@ func RelayMidjourney(c *gin.Context) {
"code": err.Code,
})
channelId := c.GetInt("channel_id")
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
}
}

97
controller/stripe.go Normal file
View File

@@ -0,0 +1,97 @@
package controller
import (
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/webhook"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
)
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := common.StripeWebhookSecret
event, err := webhook.ConstructEvent(payload, signature, endpointSecret)
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if "" == referenceId {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}

View File

@@ -1,23 +1,20 @@
package controller
import "C"
import (
"fmt"
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"one-api/constant"
"github.com/stripe/stripe-go/v76"
"github.com/stripe/stripe-go/v76/checkout/session"
"log"
"net/url"
"one-api/common"
"one-api/model"
"one-api/service"
"strconv"
"sync"
"strings"
"time"
)
type EpayRequest struct {
type PayRequest struct {
Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
@@ -28,196 +25,114 @@ type AmountRequest struct {
TopUpCode string `json:"top_up_code"`
}
func GetEpayClient() *epay.Client {
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
return nil
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
return "", fmt.Errorf("无效的Stripe API密钥")
}
withUrl, err := epay.NewClient(&epay.Config{
PartnerID: constant.EpayId,
Key: constant.EpayKey,
}, constant.PayAddress)
stripe.Key = common.StripeApiSecret
params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(common.ServerAddress + "/log"),
CancelURL: stripe.String(common.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(common.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil {
return nil
return "", err
}
return withUrl
return result.URL, nil
}
func getPayMoney(amount float64, user model.User) float64 {
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
}
// 别问为什么用float64问就是这么点钱没必要
topupGroupRatio := common.GetTopupGroupRatio(user.Group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * constant.Price * topupGroupRatio
return payMoney
func GetPayAmount(count float64) float64 {
return count * common.StripeUnitPrice
}
func getMinTopup() int {
minTopup := constant.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
func GetChargedAmount(count float64, user model.User) float64 {
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if topUpGroupRatio == 0 {
topUpGroupRatio = 1
}
return minTopup
return count * topUpGroupRatio
}
func RequestEpay(c *gin.Context) {
var req EpayRequest
func RequestPayLink(c *gin.Context) {
var req PayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.PaymentMethod != "stripe" {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney < 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
var payType epay.PurchaseType
if req.PaymentMethod == "zfb" {
payType = epay.Alipay
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = epay.WechatPay
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(common.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
client := GetEpayClient()
if client == nil {
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
referenceId := "ref_" + common.Sha1(reference)
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
amount := req.Amount
if !common.DisplayInCurrencyEnabled {
amount = amount / int(common.QuotaPerUnit)
}
topUp := &model.TopUp{
UserId: id,
Amount: amount,
Money: payMoney,
TradeNo: "A" + tradeNo,
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
CreateTime: time.Now().Unix(),
Status: "pending",
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
// tradeNo lock
var orderLocks sync.Map
var createLock sync.Mutex
// LockOrder 尝试对给定订单号加锁
func LockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if !ok {
createLock.Lock()
defer createLock.Unlock()
lock, ok = orderLocks.Load(tradeNo)
if !ok {
lock = new(sync.Mutex)
orderLocks.Store(tradeNo, lock)
}
}
lock.(*sync.Mutex).Lock()
}
// UnlockOrder 释放给定订单号的锁
func UnlockOrder(tradeNo string) {
lock, ok := orderLocks.Load(tradeNo)
if ok {
lock.(*sync.Mutex).Unlock()
}
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
return
}
}
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
log.Println("易支付回调签名验证失败")
return
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
LockOrder(verifyInfo.ServiceTradeNo)
defer UnlockOrder(verifyInfo.ServiceTradeNo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp == nil {
log.Printf("易支付回调未找到订单: %v", verifyInfo)
return
}
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"payLink": payLink,
},
})
}
func RequestAmount(c *gin.Context) {
@@ -227,17 +142,23 @@ func RequestAmount(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
if req.Amount < getMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
if !common.PaymentEnabled {
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
return
}
if req.Amount < common.MinTopUp {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
payMoney := getPayMoney(float64(req.Amount), *user)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
payMoney := GetPayAmount(float64(req.Amount))
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
},
})
}

View File

@@ -66,6 +66,7 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("username", user.Username)
session.Set("role", user.Role)
session.Set("status", user.Status)
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -216,8 +217,7 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
users, err := model.SearchUsers(keyword, group)
users, err := model.SearchUsers(keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -453,7 +453,7 @@ func UpdateUser(c *gin.Context) {
updatedUser.Password = "" // rollback to what it should be
}
updatePassword := updatedUser.Password != ""
if err := updatedUser.Edit(updatePassword); err != nil {
if err := updatedUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -517,7 +517,7 @@ func UpdateSelf(c *gin.Context) {
return
}
func DeleteUser(c *gin.Context) {
func HardDeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -526,7 +526,7 @@ func DeleteUser(c *gin.Context) {
})
return
}
originUser, err := model.GetUserById(id, false)
originUser, err := model.GetUserByIdUnscoped(id, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -550,9 +550,23 @@ func DeleteUser(c *gin.Context) {
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func DeleteSelf(c *gin.Context) {
if !common.UserSelfDeletionEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "当前设置不允许用户自我删除账号",
})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
@@ -726,7 +740,7 @@ func ManageUser(c *gin.Context) {
user.Role = common.RoleCommonUser
}
if err := user.Update(false); err != nil {
if err := user.UpdateAll(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -2,18 +2,17 @@ version: '3.4'
services:
new-api:
image: calciumion/new-api:latest
# build: .
image: pengzhile/new-api:latest
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
- "3000:3000"
volumes:
- ./data:/data
- ./data/new-api:/data
- ./logs:/app/logs
environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai
@@ -23,13 +22,22 @@ services:
depends_on:
- redis
healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s
timeout: 10s
retries: 3
- db
redis:
image: redis:latest
container_name: redis
restart: always
db:
image: mysql:8.2.0
container_name: mysql
restart: always
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: newapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: new-api # 自动创建数据库

View File

@@ -32,21 +32,6 @@ type GeneralOpenAIRequest struct {
TopLogProbs int `json:"top_logprobs,omitempty"`
}
type OpenAITools struct {
Type string `json:"type"`
Function OpenAIFunction `json:"function"`
}
type OpenAIFunction struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil

View File

@@ -54,54 +54,21 @@ type OpenAIEmbeddingResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
Logprobs *any `json:"logprobs"`
FinishReason *string `json:"finish_reason"`
Index int `json:"index"`
}
type ChatCompletionsStreamResponseChoiceDelta struct {
Content *string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
return c.Content == nil && len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
if c.Content == nil {
return ""
}
return *c.Content
}
type ToolCall struct {
// Index is not nil only in chat completion chunk object
Index *int `json:"index,omitempty"`
ID string `json:"id"`
Type any `json:"type"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Arguments string `json:"arguments,omitempty"`
Delta struct {
Content string `json:"content"`
Role string `json:"role,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {

18
go.mod
View File

@@ -4,11 +4,7 @@ module one-api
go 1.18
require (
github.com/Calcium-Ion/go-epay v0.0.2
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.26.1
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -19,11 +15,10 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.6
github.com/samber/lo v1.39.0
github.com/samber/lo v1.38.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
gorm.io/driver/mysql v1.4.3
@@ -34,10 +29,6 @@ require (
require (
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@@ -68,15 +59,16 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/stripe/stripe-go/v76 v76.21.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect

41
go.sum
View File

@@ -1,23 +1,7 @@
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -78,8 +62,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -99,8 +83,6 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -146,8 +128,6 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -155,10 +135,12 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 h1:avbt5a8F/zbYwFzTugrqWOBJe/K1cJj6+xpr+x1oVAI=
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2/go.mod h1:SiffGCWGGMVwujne2dUQbJ5zUVD1V1Yj0hDuTfqFNEo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -171,6 +153,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stripe/stripe-go/v76 v76.21.0 h1:O3GHImHS4oUI3qWMOClHN3zAQF5/oswS/NB7leV1fsU=
github.com/stripe/stripe-go/v76 v76.21.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -191,18 +175,20 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -220,6 +206,7 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View File

@@ -15,6 +15,7 @@ func authHelper(c *gin.Context, minRole int) {
role := session.Get("role")
id := session.Get("id")
status := session.Get("status")
linuxDoEnable := session.Get("linuxdo_enable")
if username == nil {
// Check access token
accessToken := c.Request.Header.Get("Authorization")
@@ -33,6 +34,7 @@ func authHelper(c *gin.Context, minRole int) {
role = user.Role
id = user.Id
status = user.Status
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -50,6 +52,14 @@ func authHelper(c *gin.Context, minRole int) {
c.Abort()
return
}
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户 LINUX DO 信任等级不足",
})
c.Abort()
return
}
if role.(int) < minRole {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -112,6 +122,15 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return
}
if !linuxDoEnabled {
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)

View File

@@ -24,9 +24,6 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id")
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -43,7 +40,72 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
shouldSelectChannel := true
// Select a channel for the user
var modelRequest ModelRequest
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
@@ -66,6 +128,8 @@ func Distribute() func(c *gin.Context) {
}
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
@@ -83,87 +147,14 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
}
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest
shouldSelectChannel := true
var err error
if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
shouldSelectChannel = false
} else {
midjourneyRequest := dto.MidjourneyRequest{}
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return nil, false, err
}
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return nil, false, fmt.Errorf(mjErr.Description)
}
if midjourneyModel == "" {
if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else {
// task fetch, task fetch by condition, notify
shouldSelectChannel = false
}
}
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, err
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
}
return &modelRequest, shouldSelectChannel, nil
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("original_model", modelName) // for retry
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
@@ -177,7 +168,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
}
c.Set("auto_ban", ban)
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一

View File

@@ -3,8 +3,6 @@ package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
)
@@ -29,63 +27,18 @@ func GetGroupModels(group string) []string {
return models
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
// 处理错误
return 0, err
}
// 确定要使用的优先级
var priorityToUse int
if retry >= len(priorities) {
// 如果重试次数大于优先级数,则使用最小的优先级
priorityToUse = priorities[len(priorities)-1]
} else {
priorityToUse = priorities[retry]
}
return priorityToUse, nil
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
return channelQuery
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
var abilities []Ability
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
channelQuery := getChannelQuery(group, model, retry)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
@@ -135,16 +88,7 @@ func (channel *Channel) AddAbilities() error {
abilities = append(abilities, ability)
}
}
if len(abilities) == 0 {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Create(&chunk).Error
if err != nil {
return err
}
}
return nil
return DB.Create(&abilities).Error
}
func (channel *Channel) DeleteAbilities() error {

View File

@@ -25,6 +25,9 @@ var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error {
if !common.RedisEnabled {
return token.SelectUpdate()
}
jsonBytes, err := json.Marshal(token)
if err != nil {
return err
@@ -165,11 +168,7 @@ func CacheUpdateUserQuota(id int) error {
if err != nil {
return err
}
return cacheSetUserQuota(id, quota)
}
func cacheSetUserQuota(id int, quota int) error {
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
return err
}
@@ -205,6 +204,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsLinuxDoEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if linuxDoEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set linuxdo enabled error: " + err.Error())
}
return linuxDoEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
@@ -273,7 +296,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model, retry)
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()

View File

@@ -25,10 +25,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
}
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
@@ -155,13 +153,6 @@ func (channel *Channel) GetModelMapping() string {
return *channel.ModelMapping
}
func (channel *Channel) GetStatusCodeMapping() string {
if channel.StatusCodeMapping == nil {
return ""
}
return *channel.StatusCodeMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error

View File

@@ -24,7 +24,6 @@ type Log struct {
IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Other string `json:"other"`
}
const (
@@ -58,13 +57,12 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
Username: username,
@@ -80,7 +78,6 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Other: otherStr,
}
err := DB.Create(log).Error
if err != nil {
@@ -93,7 +90,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@@ -118,11 +115,17 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
return logs, total, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
@@ -141,8 +144,14 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
return logs, err
return logs, total, err
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {

View File

@@ -31,10 +31,12 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
@@ -44,7 +46,6 @@ func InitOptionMap() {
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
common.OptionMap["SMTPServer"] = ""
common.OptionMap["SMTPFrom"] = ""
@@ -59,15 +60,18 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
common.OptionMap["StripePriceId"] = common.StripePriceId
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["LinuxDoClientId"] = ""
common.OptionMap["LinuxDoClientSecret"] = ""
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
common.OptionMap["TelegramBotToken"] = ""
common.OptionMap["TelegramBotName"] = ""
common.OptionMap["WeChatServerAddress"] = ""
@@ -83,7 +87,6 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -93,9 +96,6 @@ func InitOptionMap() {
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -169,6 +169,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
case "LinuxDoOAuthEnabled":
common.LinuxDoOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TelegramOAuthEnabled":
@@ -177,10 +179,10 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled":
common.RegisterEnabled = boolValue
case "UserSelfDeletionEnabled":
common.UserSelfDeletionEnabled = boolValue
case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue
case "EmailAliasRestrictionEnabled":
common.EmailAliasRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
@@ -199,12 +201,6 @@ func updateOptionMap(key string, value string) (err error) {
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
@@ -233,24 +229,30 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPToken = value
case "ServerAddress":
common.ServerAddress = value
case "PayAddress":
constant.PayAddress = value
case "CustomCallbackAddress":
constant.CustomCallbackAddress = value
case "EpayId":
constant.EpayId = value
case "EpayKey":
constant.EpayKey = value
case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64)
case "StripeApiSecret":
common.StripeApiSecret = value
case "StripeWebhookSecret":
common.StripeWebhookSecret = value
case "StripePriceId":
common.StripePriceId = value
case "PaymentEnabled":
common.PaymentEnabled, _ = strconv.ParseBool(value)
case "StripeUnitPrice":
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
constant.MinTopUp, _ = strconv.Atoi(value)
common.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
case "LinuxDoClientId":
common.LinuxDoClientId = value
case "LinuxDoClientSecret":
common.LinuxDoClientSecret = value
case "LinuxDoMinLevel":
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
case "Footer":
common.Footer = value
case "SystemName":
@@ -291,8 +293,6 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink":

View File

@@ -102,11 +102,6 @@ func GetTokenById(id int) (*Token, error) {
token := Token{Id: id}
var err error = nil
err = DB.First(&token, "id = ?", id).Error
if err != nil {
if common.RedisEnabled {
go cacheSetToken(&token)
}
}
return &token, err
}

View File

@@ -1,13 +1,21 @@
package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
)
type TopUp struct {
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"`
Money float64 `json:"money"`
TradeNo string `json:"trade_no"`
CreateTime int64 `json:"create_time"`
Status string `json:"status"`
Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"`
Amount int `json:"amount"`
Money float64 `json:"money"`
TradeNo string `json:"trade_no" gorm:"unique"`
CreateTime int64 `json:"create_time"`
CompleteTime int64 `json:"complete_time"`
Status string `json:"status"`
}
func (topUp *TopUp) Insert() error {
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
}
return topUp
}
func Recharge(referenceId string, customerId string) (err error) {
if referenceId == "" {
return errors.New("未提供支付单号")
}
var quota float64
topUp := &TopUp{}
refCol := "`trade_no`"
if common.UsingPostgreSQL {
refCol = `"trade_no"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
if err != nil {
return errors.New("充值订单不存在")
}
if topUp.Status != common.TopUpStatusPending {
return errors.New("充值订单状态错误")
}
topUp.CompleteTime = common.GetTimestamp()
topUp.Status = common.TopUpStatusSuccess
err = tx.Save(topUp).Error
if err != nil {
return err
}
quota = topUp.Money * common.QuotaPerUnit
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
if err != nil {
return err
}
return nil
})
if err != nil {
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.LogQuotaF(quota), topUp.Amount))
return nil
}

View File

@@ -22,6 +22,8 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -35,6 +37,7 @@ type User struct {
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
@@ -64,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
func GetMaxUserId() int {
var user User
DB.Last(&user)
DB.Unscoped().Last(&user)
return user.Id
}
@@ -73,34 +76,25 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
return users, err
}
func SearchUsers(keyword string, group string) ([]*User, error) {
func SearchUsers(keyword string) ([]*User, error) {
var users []*User
var err error
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功按照ID和可选的组别搜索用户
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
if group != "" {
query = query.Where("`group` = ?", group) // 使用反引号包围group
}
err = query.Find(&users).Error
// 如果转换成功按照ID搜索用户
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
if err != nil || len(users) > 0 {
// 如果依据ID找到用户或者发生错误返回结果或错误
return users, err
}
}
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
err = query.Find(&users).Error
// 如果ID转换失败或者没有找到用户依据其他字段进行模糊搜索
err = DB.Unscoped().Omit("password").
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
Find(&users).Error
return users, err
}
@@ -119,6 +113,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err
}
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
if id == 0 {
return nil, errors.New("id 为空!")
}
user := User{Id: id}
var err error = nil
if selectAll {
err = DB.Unscoped().First(&user, "id = ?", id).Error
} else {
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
}
return &user, err
}
func GetUserIdByAffCode(affCode string) (int, error) {
if affCode == "" {
return 0, errors.New("affCode 为空!")
@@ -244,7 +252,7 @@ func (user *User) Update(updatePassword bool) error {
return err
}
func (user *User) Edit(updatePassword bool) error {
func (user *User) UpdateAll(updatePassword bool) error {
var err error
if updatePassword {
user.Password, err = common.Password2Hash(user.Password)
@@ -253,17 +261,8 @@ func (user *User) Edit(updatePassword bool) error {
}
}
newUser := *user
updates := map[string]interface{}{
"username": newUser.Username,
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
}
if updatePassword {
updates["password"] = newUser.Password
}
DB.First(&user, user.Id)
err = DB.Model(user).Updates(updates).Error
err = DB.Model(user).Select("*").Updates(newUser).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
@@ -330,6 +329,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLinuxDoId() error {
if user.LinuxDoId == "" {
return errors.New("LINUX DO id 为空!")
}
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -369,6 +376,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -414,6 +425,18 @@ func IsUserEnabled(userId int) (bool, error) {
return user.Status == common.UserStatusEnabled, nil
}
func IsLinuxDoEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
}
var user User
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
if err != nil {
return false, err
}
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}
func ValidateAccessToken(token string) (user *User) {
if token == "" {
return nil
@@ -428,11 +451,6 @@ func ValidateAccessToken(token string) (user *User) {
func GetUserQuota(id int) (quota int, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
if err != nil {
if common.RedisEnabled {
go cacheSetUserQuota(id, quota)
}
}
return quota, err
}

View File

@@ -6,5 +6,3 @@ var ModelList = []string{
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}
var ChannelName = "ai360"

View File

@@ -136,7 +136,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
@@ -199,7 +199,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -1,79 +0,0 @@
package aws
import (
"errors"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
a.RequestMode = RequestModeCompletion
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
var claudeReq *claude.ClaudeRequest
var err error
if a.RequestMode == RequestModeCompletion {
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
} else {
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
}
c.Set("request_model", request.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,12 +0,0 @@
package aws
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
}
var ChannelName = "aws"

View File

@@ -1,15 +0,0 @@
package aws
import "one-api/relay/channel/claude"
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
}

View File

@@ -1,213 +0,0 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"io"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 3 {
return nil, errors.New("invalid aws secret key")
}
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client := bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
return &relaymodel.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
}
return "", errors.Errorf("model %s not found", requestModel)
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(claude.ClaudeResponse)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
}
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
}
awsModelId, err := awsModelID(c.GetString("request_model"))
if err != nil {
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
awsClaudeReq := &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
}
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
return wrapErr(errors.Wrap(err, "copy request")), nil
}
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return wrapErr(errors.Wrap(err, "marshal request")), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
}
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var model string
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
if response == nil {
return true
}
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
default:
fmt.Println("union is nil or unknown type")
return false
}
})
return nil, &usage
}

View File

@@ -57,7 +57,7 @@ func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(baiduResponse.Result)
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &relaycommon.StopFinishReason
}

View File

@@ -53,9 +53,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil
return requestOpenAI2ClaudeComplete(*request), nil
} else {
return RequestOpenAI2ClaudeMessage(*request)
return requestOpenAI2ClaudeMessage(*request)
}
}

View File

@@ -28,8 +28,8 @@ type ClaudeRequest struct {
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`

View File

@@ -20,24 +20,25 @@ func stopReasonClaude2OpenAI(reason string) string {
case "end_turn":
return "stop"
case "max_tokens":
return "max_tokens"
return "length"
default:
return reason
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
Model: textRequest.Model,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 4096
claudeRequest.MaxTokensToSample = 1000000
}
prompt := ""
for _, message := range textRequest.Messages {
@@ -56,7 +57,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
@@ -69,52 +70,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
//if message.Role == "system" {
// if i != 0 {
// message.Role = "user"
// }
//}
if message.Role == "" {
textRequest.Messages[i].Role = "user"
}
fmtMessage := dto.Message{
Role: message.Role,
Content: message.Content,
}
if lastMessage != nil && lastMessage.Role == message.Role {
if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content
// delete last message
formatMessages = formatMessages[:len(formatMessages)-1]
}
}
if fmtMessage.Content == nil {
content, _ := json.Marshal("...")
fmtMessage.Content = content
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = &textRequest.Messages[i]
}
claudeMessages := make([]ClaudeMessage, 0)
for _, message := range formatMessages {
for _, message := range textRequest.Messages {
if message.Role == "system" {
if message.IsStringContent() {
claudeRequest.System = message.StringContent()
} else {
contents := message.ParseContent()
content := ""
for _, ctx := range contents {
if ctx.Type == "text" {
content += ctx.Text
}
}
claudeRequest.System = content
}
claudeRequest.System = message.StringContent()
} else {
claudeMessage := ClaudeMessage{
Role: message.Role,
@@ -159,10 +118,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
@@ -170,7 +130,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
@@ -180,34 +140,25 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
return nil, nil
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil, nil
} else {
return nil, nil
}
}
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -291,10 +242,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
return true
}
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
@@ -369,7 +317,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil

View File

@@ -1,52 +0,0 @@
package cohere
import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,7 +0,0 @@
package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
}
var ChannelName = "cohere"

View File

@@ -1,44 +0,0 @@
package cohere
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int64 `json:"max_tokens"`
}
type ChatHistory struct {
Role string `json:"role"`
Message string `json:"message"`
}
type CohereResponse struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
Text string `json:"text,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Response *CohereResponseResult `json:"response"`
}
type CohereResponseResult struct {
ResponseId string `json:"response_id"`
FinishReason string `json:"finish_reason,omitempty"`
Text string `json:"text"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`
}
type CohereBilledUnits struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type CohereTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -1,189 +0,0 @@
package cohere
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
cohereReq := CohereRequest{
Model: textRequest.Model,
ChatHistory: []ChatHistory{},
Message: "",
Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(),
}
if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000
}
for _, msg := range textRequest.Messages {
if msg.Role == "user" {
cohereReq.Message = msg.StringContent()
} else {
var role string
if msg.Role == "assistant" {
role = "CHATBOT"
} else if msg.Role == "system" {
role = "SYSTEM"
} else {
role = "USER"
}
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
Role: role,
Message: msg.StringContent(),
})
}
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
return "stop"
case "MAX_TOKENS":
return "max_tokens"
default:
return reason
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
Index: 0,
FinishReason: &finishReason,
},
}
if cohereResp.Response != nil {
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
}
} else {
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
Content: &cohereResp.Text,
},
Index: 0,
},
}
responseText += cohereResp.Text
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, usage
}
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
var openaiResp dto.TextResponse
openaiResp.Id = cohereResp.ResponseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion"
openaiResp.Model = modelName
openaiResp.Usage = usage
content, _ := json.Marshal(cohereResp.Text)
openaiResp.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
Message: dto.Message{Content: content, Role: "assistant"},
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
},
}
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -18,28 +18,16 @@ type Adaptor struct {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
// 定义一个映射,存储模型名称和对应的版本
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
version := "v1"
if info.ApiVersion != "" {
version = info.ApiVersion
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {

View File

@@ -5,8 +5,8 @@ const (
)
var ModelList = []string{
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-ultra",
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
"gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-pro-vision", "gemini-1.0-pro-vision-001",
}
var ChannelName = "google gemini"

View File

@@ -151,7 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(geminiResponse.GetResponseText())
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
@@ -203,7 +203,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
choice.Delta.Content = dummy.Content
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",

View File

@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {

View File

@@ -1,7 +1,5 @@
package ollama
var ModelList = []string{
"llama3-7b",
}
var ModelList []string
var ChannelName = "ollama"

View File

@@ -9,7 +9,6 @@ import (
"net/http"
"one-api/dto"
"one-api/service"
"strings"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
@@ -42,7 +41,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Prompt: strings.Join(request.ParseInput(), " "),
Prompt: request.Input,
}
}

View File

@@ -72,10 +72,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -6,7 +6,7 @@ var ModelList = []string{
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",

View File

@@ -16,10 +16,9 @@ import (
"time"
)
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@@ -50,7 +49,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
common.SafeSendString(dataChan, data)
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
@@ -68,32 +67,14 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
}
@@ -123,7 +104,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSendBool(stopChan, true)
common.SafeSend(stopChan, true)
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
@@ -142,10 +123,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
wg.Wait()
return nil, responseTextBuilder.String(), toolCount
return nil, responseTextBuilder.String()
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

View File

@@ -61,7 +61,7 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse

View File

@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

View File

@@ -86,7 +86,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
}
if len(TencentResponse.Choices) > 0 {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &relaycommon.StopFinishReason
}
@@ -138,7 +138,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
responseText += response.Choices[0].Delta.Content
}
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -87,7 +87,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCo
}
}
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &relaycommon.StopFinishReason
}
@@ -179,13 +179,7 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
case stop = <-stopChan:
}
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)

View File

@@ -126,7 +126,7 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(zhipuResponse)
choice.Delta.Content = zhipuResponse
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
@@ -138,7 +138,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStream
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString("")
choice.Delta.Content = ""
choice.FinishReason = &relaycommon.StopFinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId,

View File

@@ -47,10 +47,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -18,25 +18,13 @@ const (
APITypeZhipu_v4
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) int {
apiType := -1
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeOpenAI:
apiType = APITypeOpenAI
case common.ChannelTypeAzure:
apiType = APITypeOpenAI
case common.ChannelTypeMoonshot:
apiType = APITypeOpenAI
case common.ChannelTypeLingYiWanWu:
apiType = APITypeOpenAI
case common.ChannelType360:
apiType = APITypeOpenAI
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
@@ -61,10 +49,6 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypeOllama
case common.ChannelTypePerplexity:
apiType = APITypePerplexity
case common.ChannelTypeAws:
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
}
return apiType
}

View File

@@ -20,6 +20,15 @@ import (
"time"
)
var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
@@ -50,6 +59,9 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if audioRequest.Voice == "" {
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
var err error
promptTokens := 0
@@ -88,22 +100,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
}
succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func() {
go func() {
// negative means add quota back for token & user
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
}()
}()
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
@@ -167,7 +163,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK {
return relaycommon.RelayErrorHandler(resp)
}
succeed = true
var audioResponse dto.AudioResponse
@@ -196,10 +191,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-3"
imageRequest.Model = "dall-e-2"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
@@ -186,15 +186,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
}
if quota != 0 {
tokenName := c.GetString("token_name")
quality := "normal"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -110,13 +110,11 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
if originTask.ImageUrl != "" {
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
} else {
midjourneyTask.ImageUrl = originTask.ImageUrl
}
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
@@ -202,10 +200,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -501,10 +496,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -154,28 +154,20 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
requestBody = bytes.NewBuffer(jsonData)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return service.RelayErrorHandler(resp)
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
@@ -189,7 +181,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
checkSensitive := constant.ShouldCheckPromptSensitive()
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
case relayconstant.RelayModeCompletions:
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
case relayconstant.RelayModeModerations:
@@ -264,10 +256,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
completionTokens := usage.CompletionTokens
tokenName := ctx.GetString("token_name")
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
@@ -279,7 +271,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
totalTokens := promptTokens + completionTokens
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
}
@@ -315,12 +307,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
other["completion_ratio"] = completionRatio
other["model_price"] = modelPrice
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream)
//if quota != 0 {
//

View File

@@ -3,10 +3,8 @@ package relay
import (
"one-api/relay/channel"
"one-api/relay/channel/ali"
"one-api/relay/channel/aws"
"one-api/relay/channel/baidu"
"one-api/relay/channel/claude"
"one-api/relay/channel/cohere"
"one-api/relay/channel/gemini"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
@@ -47,10 +45,6 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &ollama.Adaptor{}
case constant.APITypePerplexity:
return &perplexity.Adaptor{}
case constant.APITypeAws:
return &aws.Adaptor{}
case constant.APITypeCohere:
return &cohere.Adaptor{}
}
return nil
}

View File

@@ -14,16 +14,16 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
//apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
@@ -31,13 +31,14 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
selfRoute := userRoute.Group("/")
selfRoute.Use(middleware.UserAuth())
@@ -48,8 +49,8 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
}
@@ -63,7 +64,7 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/", controller.CreateUser)
adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser)
adminRoute.DELETE("/:id", controller.HardDeleteUser)
}
}
optionRoute := apiRouter.Group("/option")

View File

@@ -57,13 +57,11 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
} else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool {
func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
@@ -73,8 +71,5 @@ func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status in
if openAIErr != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {
return false
}
return true
}

View File

@@ -1,13 +0,0 @@
package service
import (
"one-api/common"
"one-api/constant"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
return common.ServerAddress
}
return constant.CustomCallbackAddress
}

View File

@@ -86,22 +86,3 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
}
return
}
func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
return
}
statusCodeMapping := make(map[string]string)
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
if err != nil {
return
}
if openaiErr.StatusCode == http.StatusOK {
return
}
codeStr := strconv.Itoa(openaiErr.StatusCode)
if _, ok := statusCodeMapping[codeStr]; ok {
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
openaiErr.StatusCode = intCode
}
}

View File

@@ -165,24 +165,13 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
if !constant.MjAccountFilterEnabled {
delete(mapResult, "accountFilter")
}
delete(mapResult, "accountFilter")
if !constant.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
}
if constant.MjModeClearEnabled {
if prompt, ok := mapResult["prompt"].(string); ok {
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)
prompt = strings.Replace(prompt, "--turbo", "", -1)
mapResult["prompt"] = prompt
}
}
reqBody, err := json.Marshal(mapResult)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err

View File

@@ -116,41 +116,6 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
return tiles*170 + 85, nil
}
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
tkm := 0
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
if err != nil {
return 0, err, b
}
tkm += msgTokens
if request.Tools != nil {
toolsData, _ := json.Marshal(request.Tools)
var openaiTools []dto.OpenAITools
err := json.Unmarshal(toolsData, &openaiTools)
if err != nil {
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
}
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err, _ := CountTokenInput(countStr, model, false)
if err != nil {
return 0, err, false
}
tkm += 8
tkm += toolTokens
}
return tkm, nil, false
}
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
@@ -173,31 +138,48 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if len(message.Content) > 0 {
if message.IsStringContent() {
stringContent := message.StringContent()
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
var arrayContent []dto.MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err, false
} else {
if checkSensitive {
contains, words := SensitiveWordContains(stringContent)
if contains {
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
return 0, err, true
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
} else {
var err error
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if model == "glm-4v" {
imageTokenNum = 1047
} else {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err = getImageToken(&imageUrl)
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := dto.MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err, false
}
@@ -229,23 +211,6 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
}
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _, _ := CountTokenInput(message.Delta.GetContentString(), model, false)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
tokens += tkm
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
tokens += tkm
}
}
}
return tokens
}
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
if strings.HasPrefix(model, "tts") {
contains, words := SensitiveWordContains(text)

5
web/.gitignore vendored
View File

@@ -10,6 +10,7 @@
# production
/build
/dist
# misc
.DS_Store
@@ -21,6 +22,4 @@
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.idea
package-lock.json
yarn.lock
.idea/

View File

@@ -11,6 +11,7 @@ import EditUser from './pages/User/EditUser';
import { getLogo, getSystemName } from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
import LinuxDoOAuth from './components/LinuxDoOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import Channel from './pages/Channel';
@@ -171,6 +172,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={
<Suspense fallback={<Loading></Loading>}>
<LinuxDoOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={

View File

@@ -31,7 +31,6 @@ import {
} from '@douyinfe/semi-ui';
import EditChannel from '../pages/Channel/EditChannel';
import { IconTreeTriangleDown } from '@douyinfe/semi-icons';
import { loadChannelModels } from './utils.js';
function renderTimestamp(timestamp) {
return <>{timestamp2string(timestamp)}</>;
@@ -255,19 +254,6 @@ const ChannelsTable = () => {
>
编辑
</Button>
<Popconfirm
title='确定是否要复制此渠道?'
content='复制渠道的所有信息'
okType={'danger'}
position={'left'}
onConfirm={async () => {
copySelectedChannel(record.id);
}}
>
<Button theme='light' type='primary' style={{ marginRight: 1 }}>
复制
</Button>
</Popconfirm>
</div>
),
},
@@ -354,33 +340,6 @@ const ChannelsTable = () => {
setLoading(false);
};
const copySelectedChannel = async (id) => {
const channelToCopy = channels.find(
(channel) => String(channel.id) === String(id),
);
console.log(channelToCopy);
channelToCopy.name += '_复制';
channelToCopy.created_time = null;
channelToCopy.balance = 0;
channelToCopy.used_quota = 0;
if (!channelToCopy) {
showError('渠道未找到,请刷新页面后重试。');
return;
}
try {
const newChannel = { ...channelToCopy, id: undefined };
const response = await API.post('/api/channel/', newChannel);
if (response.data.success) {
showSuccess('渠道复制成功');
await refresh();
} else {
showError(response.data.message);
}
} catch (error) {
showError('渠道复制失败: ' + error.message);
}
};
const refresh = async () => {
await loadChannels(activePage - 1, pageSize, idSort);
};
@@ -398,7 +357,6 @@ const ChannelsTable = () => {
showError(reason);
});
fetchGroups().then();
loadChannelModels().then();
}, []);
const manageChannel = async (id, action, record, value) => {

View File

@@ -14,9 +14,14 @@ const GitHubOAuth = () => {
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
let aff = localStorage.getItem('aff');
const res = await API.get(
`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`,
);
const { success, message, data } = res.data;
if (success) {
localStorage.removeItem('aff');
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
@@ -41,6 +46,14 @@ const GitHubOAuth = () => {
};
useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();

View File

@@ -1,7 +1,6 @@
import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate } from 'react-router-dom';
import { UserContext } from '../context/User';
import { useSetTheme, useTheme } from '../context/Theme';
import { API, getLogo, getSystemName, showSuccess } from '../helpers';
import '../index.css';
@@ -35,8 +34,10 @@ const HeaderBar = () => {
let navigate = useNavigate();
const [showSidebar, setShowSidebar] = useState(false);
const [dark, setDark] = useState(false);
const systemName = getSystemName();
const logo = getLogo();
var themeMode = localStorage.getItem('theme-mode');
const currentDate = new Date();
// enable fireworks on new year(1.1 and 2.9-2.24)
const isNewYear =
@@ -65,19 +66,26 @@ const HeaderBar = () => {
}, 3000);
};
const theme = useTheme();
const setTheme = useSetTheme();
useEffect(() => {
if (theme === 'dark') {
document.body.setAttribute('theme-mode', 'dark');
if (themeMode === 'dark') {
switchMode(true);
}
if (isNewYear) {
console.log('Happy New Year!');
}
}, []);
const switchMode = (model) => {
const body = document.body;
if (!model) {
body.removeAttribute('theme-mode');
localStorage.setItem('theme-mode', 'light');
} else {
body.setAttribute('theme-mode', 'dark');
localStorage.setItem('theme-mode', 'dark');
}
setDark(model);
};
return (
<>
<Layout>
@@ -124,11 +132,9 @@ const HeaderBar = () => {
<Switch
checkedText='🌞'
size={'large'}
checked={theme === 'dark'}
checked={dark}
uncheckedText='🌙'
onChange={(checked) => {
setTheme(checked);
}}
onChange={switchMode}
/>
{userState.user ? (
<>

View File

@@ -0,0 +1,31 @@
import React from 'react';
import { Icon } from '@douyinfe/semi-ui';
const LinuxDoIcon = (props) => {
function CustomIcon() {
return (
<svg
className='icon'
viewBox='0 0 24 24'
version='1.1'
xmlns='http://www.w3.org/2000/svg'
width='16'
height='16'
{...props}
>
<path
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
fill='currentColor'
/>
</svg>
);
}
return (
<div>
<Icon svg={<CustomIcon />} />
</div>
);
};
export default LinuxDoIcon;

View File

@@ -0,0 +1,71 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const LinuxDoOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
let aff = localStorage.getItem('aff');
const res = await API.get(
`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`,
);
const { success, message, data } = res.data;
if (success) {
localStorage.removeItem('aff');
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind GitHub
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let error = searchParams.get('error');
if (error) {
let errorDescription = searchParams.get('error_description');
showError(`授权错误:${error}: ${errorDescription}`);
navigate('/setting');
return;
}
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default LinuxDoOAuth;

View File

@@ -2,7 +2,7 @@ import React, { useContext, useEffect, useState } from 'react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
import Turnstile from 'react-turnstile';
import {
Button,
@@ -18,6 +18,7 @@ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import TelegramLoginButton from 'react-telegram-login';
import { IconGithubLogo } from '@douyinfe/semi-icons';
import LinuxDoIcon from './LinuxDoIcon';
import WeChatIcon from './WeChatIcon';
const LoginForm = () => {
@@ -207,6 +208,7 @@ const LoginForm = () => {
</Text>
</div>
{status.github_oauth ||
status.linuxdo_oauth ||
status.wechat_login ||
status.telegram_oauth ? (
<>
@@ -231,6 +233,18 @@ const LoginForm = () => {
) : (
<></>
)}
{status.linuxdo_oauth ? (
<Button
type='primary'
icon={<LinuxDoIcon />}
style={{ color: '#000' }}
onClick={() =>
onLinuxDoOAuthClicked(status.linuxdo_client_id)
}
/>
) : (
<></>
)}
{status.wechat_login ? (
<Button
type='primary'
@@ -241,25 +255,16 @@ const LoginForm = () => {
) : (
<></>
)}
{status.telegram_oauth ? (
<TelegramLoginButton
dataOnauth={onTelegramLoginClicked}
botName={status.telegram_bot_name}
/>
) : (
<></>
)}
</div>
{status.telegram_oauth ? (
<>
<div
style={{
display: 'flex',
justifyContent: 'center',
marginTop: 5,
}}
>
<TelegramLoginButton
dataOnauth={onTelegramLoginClicked}
botName={status.telegram_bot_name}
/>
</div>
</>
) : (
<></>
)}
</>
) : (
<></>

View File

@@ -19,15 +19,9 @@ import {
Spin,
Table,
Tag,
Tooltip,
} from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import {
renderModelPrice,
renderNumber,
renderQuota,
stringToColor,
} from '../helpers/render';
import { renderNumber, renderQuota, stringToColor } from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
const { Header } = Layout;
@@ -298,44 +292,16 @@ const LogsTable = () => {
title: '详情',
dataIndex: 'content',
render: (text, record, index) => {
if (record.other === '') {
return (
<Paragraph
ellipsis={{
rows: 2,
showTooltip: {
type: 'popover',
opts: { style: { width: 240 } },
},
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
);
}
let other = JSON.parse(record.other);
let content = renderModelPrice(
other.model_ratio,
other.model_price,
other.completion_ratio,
other.group_ratio,
);
return (
<Tooltip content={content}>
<Paragraph
ellipsis={{
rows: 2,
showTooltip: {
type: 'popover',
opts: { style: { width: 240 } },
},
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
</Tooltip>
<Paragraph
ellipsis={{
rows: 2,
showTooltip: { type: 'popover', opts: { style: { width: 240 } } },
}}
style={{ maxWidth: 240 }}
>
{text}
</Paragraph>
);
},
},
@@ -346,7 +312,7 @@ const LogsTable = () => {
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [logCount, setLogCount] = useState(0);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
@@ -443,14 +409,14 @@ const LogsTable = () => {
}
};
const setLogsFormat = (logs) => {
const setLogsFormat = (logs, total) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
setLogCount(total);
// console.log(logCount);
};
@@ -466,14 +432,14 @@ const LogsTable = () => {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const { success, message, data } = res.data;
const { success, message, total, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data);
setLogsFormat(data, total);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * pageSize, data.length, ...data);
setLogsFormat(newLogs);
setLogsFormat(newLogs, total);
}
} else {
showError(message);
@@ -640,7 +606,9 @@ const LogsTable = () => {
type='primary'
htmlType='submit'
className='btn-margin-right'
onClick={refresh}
onClick={() => {
refresh(logType).then();
}}
loading={loading}
>
查询

View File

@@ -236,31 +236,6 @@ const renderTimestamp = (timestampInSeconds) => {
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
};
// 修改renderDuration函数以包含颜色逻辑
function renderDuration(submit_time, finishTime) {
// 确保startTime和finishTime都是有效的时间戳
if (!submit_time || !finishTime) return 'N/A';
// 将时间戳转换为Date对象
const start = new Date(submit_time);
const finish = new Date(finishTime);
// 计算时间差(毫秒)
const durationMs = finish - start;
// 将时间差转换为秒,并保留一位小数
const durationSec = (durationMs / 1000).toFixed(1);
// 设置颜色大于60秒则为红色小于等于60秒则为绿色
const color = durationSec > 60 ? 'red' : 'green';
// 返回带有样式的颜色标签
return (
<Tag color={color} size="large">
{durationSec}
</Tag>
);
}
const LogsTable = () => {
const [isModalOpen, setIsModalOpen] = useState(false);
@@ -273,15 +248,6 @@ const LogsTable = () => {
return <div>{renderTimestamp(text / 1000)}</div>;
},
},
{
title: '花费时间',
dataIndex: 'finish_time', // 以finish_time作为dataIndex
key: 'finish_time',
render: (finish, record) => {
// 假设record.start_time是存在的并且finish是完成时间的时间戳
return renderDuration(record.submit_time, finish);
},
},
{
title: '渠道',
dataIndex: 'channel_id',

View File

@@ -8,8 +8,6 @@ import {
verifyJSON,
} from '../helpers';
import { useTheme } from '../context/Theme';
const OperationSetting = () => {
let now = new Date();
let [inputs, setInputs] = useState({
@@ -20,7 +18,6 @@ const OperationSetting = () => {
PreConsumedQuota: 0,
StreamCacheQueueLength: 0,
ModelRatio: '',
CompletionRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
@@ -39,9 +36,6 @@ const OperationSetting = () => {
StopOnSensitiveEnabled: '',
SensitiveWords: '',
MjNotifyEnabled: '',
MjAccountFilterEnabled: '',
MjModeClearEnabled: '',
MjForwardUrlEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportDefaultTime: 'hour',
@@ -69,7 +63,6 @@ const OperationSetting = () => {
if (
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
@@ -83,9 +76,6 @@ const OperationSetting = () => {
}
};
const theme = useTheme();
const isDark = theme === 'dark';
useEffect(() => {
getOptions().then();
}, []);
@@ -159,13 +149,6 @@ const OperationSetting = () => {
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['CompletionRatio'] !== inputs.CompletionRatio) {
if (!verifyJSON(inputs.CompletionRatio)) {
showError('模型补全倍率不是合法的 JSON 字符串');
return;
}
await updateOption('CompletionRatio', inputs.CompletionRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
@@ -235,10 +218,8 @@ const OperationSetting = () => {
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading} inverted={isDark}>
<Header as='h3' inverted={isDark}>
通用设置
</Header>
<Form loading={loading}>
<Header as='h3'>通用设置</Header>
<Form.Group widths={4}>
<Form.Input
label='充值链接'
@@ -317,9 +298,7 @@ const OperationSetting = () => {
保存通用设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
绘图设置
</Header>
<Header as='h3'>绘图设置</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.DrawingEnabled === 'true'}
@@ -333,29 +312,9 @@ const OperationSetting = () => {
name='MjNotifyEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjAccountFilterEnabled === 'true'}
label='允许AccountFilter参数'
name='MjAccountFilterEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjForwardUrlEnabled === 'true'}
label='开启之后将上游地址替换为服务器地址'
name='MjForwardUrlEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.MjModeClearEnabled === 'true'}
label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
name='MjModeClearEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3' inverted={isDark}>
屏蔽词过滤设置
</Header>
<Header as='h3'>屏蔽词过滤设置</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.CheckSensitiveEnabled === 'true'}
@@ -415,9 +374,7 @@ const OperationSetting = () => {
保存屏蔽词设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
日志设置
</Header>
<Header as='h3'>日志设置</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
@@ -445,9 +402,7 @@ const OperationSetting = () => {
清理历史日志
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
数据看板
</Header>
<Header as='h3'>数据看板</Header>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label='启用数据看板(实验性)'
@@ -477,9 +432,7 @@ const OperationSetting = () => {
/>
</Form.Group>
<Divider />
<Header as='h3' inverted={isDark}>
监控设置
</Header>
<Header as='h3'>监控设置</Header>
<Form.Group widths={3}>
<Form.Input
label='最长响应时间'
@@ -524,9 +477,7 @@ const OperationSetting = () => {
保存监控设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
额度设置
</Header>
<Header as='h3'>额度设置</Header>
<Form.Group widths={4}>
<Form.Input
label='新用户初始额度'
@@ -577,9 +528,7 @@ const OperationSetting = () => {
保存额度设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
倍率设置
</Header>
<Header as='h3'>倍率设置</Header>
<Form.Group widths='equal'>
<Form.TextArea
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
@@ -602,17 +551,6 @@ const OperationSetting = () => {
placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='模型补全倍率(仅对自定义模型有效)'
name='CompletionRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
value={inputs.CompletionRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='分组倍率'

View File

@@ -10,7 +10,7 @@ import {
} from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
import {
Avatar,
Banner,
@@ -519,6 +519,39 @@ const PersonalSetting = () => {
</div>
</div>
</div>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>LINUX DO</Typography.Text>
<div
style={{ display: 'flex', justifyContent: 'space-between' }}
>
<div>
<Input
value={
userState.user && userState.user.linuxdo_id !== ''
? userState.user.linuxdo_id +
'' +
userState.user.linuxdo_level +
'级)'
: '未绑定'
}
readonly={true}
></Input>
</div>
<div>
<Button
onClick={() => {
onLinuxDoOAuthClicked(status.linuxdo_client_id);
}}
disabled={
(userState.user && userState.user.linuxdo_id !== '') ||
!status.linuxdo_oauth
}
>
{status.linuxdo_oauth ? '绑定' : '未启用'}
</Button>
</div>
</div>
</div>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>Telegram</Typography.Text>

View File

@@ -77,6 +77,8 @@ const RegisterForm = () => {
);
const { success, message } = res.data;
if (success) {
localStorage.removeItem('aff');
navigate('/login');
showSuccess('注册成功!');
} else {

View File

@@ -10,8 +10,6 @@ import {
} from 'semantic-ui-react';
import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers';
import { useTheme } from '../context/Theme';
const SystemSetting = () => {
let [inputs, setInputs] = useState({
PasswordLoginEnabled: '',
@@ -20,6 +18,10 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
LinuxDoOAuthEnabled: '',
LinuxDoClientId: '',
LinuxDoClientSecret: '',
LinuxDoMinLevel: 0,
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -27,13 +29,13 @@ const SystemSetting = () => {
SMTPFrom: '',
SMTPToken: '',
ServerAddress: '',
EpayId: '',
EpayKey: '',
Price: 7.3,
MinTopUp: 1,
StripeApiSecret: '',
StripeWebhookSecret: '',
StripePriceId: '',
PaymentEnabled: false,
StripeUnitPrice: 8.0,
MinTopUp: 5,
TopupGroupRatio: '',
PayAddress: '',
CustomCallbackAddress: '',
Footer: '',
WeChatAuthEnabled: '',
WeChatServerAddress: '',
@@ -43,8 +45,8 @@ const SystemSetting = () => {
TurnstileSiteKey: '',
TurnstileSecretKey: '',
RegisterEnabled: '',
UserSelfDeletionEnabled: false,
EmailDomainRestrictionEnabled: '',
EmailAliasRestrictionEnabled: '',
SMTPSSLEnabled: '',
EmailDomainWhitelist: [],
// telegram login
@@ -59,9 +61,6 @@ const SystemSetting = () => {
const [showPasswordWarningModal, setShowPasswordWarningModal] =
useState(false);
const theme = useTheme();
const isDark = theme === 'dark';
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
@@ -101,13 +100,15 @@ const SystemSetting = () => {
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
case 'GitHubOAuthEnabled':
case 'LinuxDoOAuthEnabled':
case 'WeChatAuthEnabled':
case 'TelegramOAuthEnabled':
case 'TurnstileCheckEnabled':
case 'EmailDomainRestrictionEnabled':
case 'EmailAliasRestrictionEnabled':
case 'SMTPSSLEnabled':
case 'RegisterEnabled':
case 'UserSelfDeletionEnabled':
case 'PaymentEnabled':
value = inputs[key] === 'true' ? 'false' : 'true';
break;
default:
@@ -122,9 +123,6 @@ const SystemSetting = () => {
if (key === 'EmailDomainWhitelist') {
value = value.split(',');
}
if (key === 'Price') {
value = parseFloat(value);
}
setInputs((inputs) => ({
...inputs,
[key]: value,
@@ -145,12 +143,16 @@ const SystemSetting = () => {
name === 'Notice' ||
(name.startsWith('SMTP') && name !== 'SMTPSSLEnabled') ||
name === 'ServerAddress' ||
name === 'EpayId' ||
name === 'EpayKey' ||
name === 'Price' ||
name === 'PayAddress' ||
name === 'StripeApiSecret' ||
name === 'StripeWebhookSecret' ||
name === 'StripePriceId' ||
name === 'StripeUnitPrice' ||
name === 'MinTopUp' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'LinuxDoClientId' ||
name === 'LinuxDoClientSecret' ||
name === 'LinuxDoMinLevel' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
@@ -172,7 +174,7 @@ const SystemSetting = () => {
await updateOption('ServerAddress', ServerAddress);
};
const submitPayAddress = async () => {
const submitPaymentConfig = async () => {
if (inputs.ServerAddress === '') {
showError('请先填写服务器地址');
return;
@@ -184,15 +186,31 @@ const SystemSetting = () => {
}
await updateOption('TopupGroupRatio', inputs.TopupGroupRatio);
}
let PayAddress = removeTrailingSlash(inputs.PayAddress);
await updateOption('PayAddress', PayAddress);
if (inputs.EpayId !== '') {
await updateOption('EpayId', inputs.EpayId);
let stripeApiSecret = removeTrailingSlash(inputs.StripeApiSecret);
if (stripeApiSecret && !stripeApiSecret.startsWith('sk_')) {
showError('输入了无效的Stripe API密钥');
return;
}
if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') {
await updateOption('EpayKey', inputs.EpayKey);
stripeApiSecret && (await updateOption('StripeApiSecret', stripeApiSecret));
let stripeWebhookSecret = removeTrailingSlash(inputs.StripeWebhookSecret);
if (stripeWebhookSecret && !stripeWebhookSecret.startsWith('whsec_')) {
showError('输入了无效的Stripe Webhook签名密钥');
return;
}
await updateOption('Price', '' + inputs.Price);
stripeWebhookSecret &&
(await updateOption('StripeWebhookSecret', stripeWebhookSecret));
let stripePriceId = removeTrailingSlash(inputs.StripePriceId);
if (stripePriceId && !stripePriceId.startsWith('price_')) {
showError('输入了无效的Stripe 物品价格ID');
return;
}
await updateOption('StripePriceId', stripePriceId);
await updateOption('PaymentEnable', inputs.PaymentEnabled);
await updateOption('StripeUnitPrice', inputs.StripeUnitPrice);
await updateOption('MinTopUp', inputs.MinTopUp);
};
const submitSMTP = async () => {
@@ -268,6 +286,21 @@ const SystemSetting = () => {
}
};
const submitLinuxDoOAuth = async () => {
if (originInputs['LinuxDoClientId'] !== inputs.LinuxDoClientId) {
await updateOption('LinuxDoClientId', inputs.LinuxDoClientId);
}
if (
originInputs['LinuxDoClientSecret'] !== inputs.LinuxDoClientSecret &&
inputs.LinuxDoClientSecret !== ''
) {
await updateOption('LinuxDoClientSecret', inputs.LinuxDoClientSecret);
}
if (originInputs['LinuxDoMinLevel'] !== inputs.LinuxDoMinLevel) {
await updateOption('LinuxDoMinLevel', inputs.LinuxDoMinLevel);
}
};
const submitTelegramSettings = async () => {
// await updateOption('TelegramOAuthEnabled', inputs.TelegramOAuthEnabled);
await updateOption('TelegramBotToken', inputs.TelegramBotToken);
@@ -311,10 +344,8 @@ const SystemSetting = () => {
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading} inverted={isDark}>
<Header as='h3' inverted={isDark}>
通用设置
</Header>
<Form loading={loading}>
<Header as='h3'>通用设置</Header>
<Form.Group widths='equal'>
<Form.Input
label='服务器地址'
@@ -328,53 +359,73 @@ const SystemSetting = () => {
更新服务器地址
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
支付设置当前仅支持易支付接口默认使用上方服务器地址作为回调地址
<Header as='h3'>
支付设置当前仅支持Stripe Checkout
<Header.Subheader>
密钥Webhook 等设置请
<a
href='https://dashboard.stripe.com/developers'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
进行设置最好先在
<a
href='https://dashboard.stripe.com/test/developers'
target='_blank'
rel='noreferrer'
>
测试环境
</a>
进行测试
</Header.Subheader>
</Header>
<Message>
Webhook
<code>{`${inputs.ServerAddress}/api/stripe/webhook`}</code>
需要包含事件<code>checkout.session.completed</code> {' '}
<code>checkout.session.expired</code>
</Message>
<Form.Group widths='equal'>
<Form.Input
label='支付地址,不填写则不启用在线支付'
placeholder='例如https://yourdomain.com'
value={inputs.PayAddress}
name='PayAddress'
label='API密钥'
placeholder='sk_xxx的Stripe密钥敏感信息不显示'
value={inputs.StripeApiSecret}
name='StripeApiSecret'
onChange={handleInputChange}
/>
<Form.Input
label='易支付商户ID'
placeholder='例如0001'
value={inputs.EpayId}
name='EpayId'
label='Webhook签名密钥'
placeholder='whsec_xxx的Webhook签名密钥敏感信息不显示'
value={inputs.StripeWebhookSecret}
name='StripeWebhookSecret'
onChange={handleInputChange}
/>
<Form.Input
label='易支付商户密钥'
placeholder='敏感信息不会发送到前端显示'
value={inputs.EpayKey}
name='EpayKey'
label='商品价格ID'
placeholder='price_xxx的商品价格ID新建产品后可获得'
value={inputs.StripePriceId}
name='StripePriceId'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.Input
label='回调地址,不填写则使用上方服务器地址作为回调地址'
placeholder='例如https://yourdomain.com'
value={inputs.CustomCallbackAddress}
name='CustomCallbackAddress'
onChange={handleInputChange}
/>
<Form.Input
label='充值价格x元/美金)'
placeholder='例如7就是7元/美金'
value={inputs.Price}
name='Price'
label='商品单价(元)'
placeholder='商品的人民币价格'
value={inputs.StripeUnitPrice}
name='StripeUnitPrice'
type={'number'}
min={0}
onChange={handleInputChange}
/>
<Form.Input
label='最低充值美元数量(以美金为单位,如果使用额度请自行换算!)'
placeholder='例如2就是最低充值2$'
label='最低充值数量'
placeholder='例如2就是最低充值2件商品'
value={inputs.MinTopUp}
name='MinTopUp'
type={'number'}
min={1}
onChange={handleInputChange}
/>
@@ -390,11 +441,19 @@ const SystemSetting = () => {
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
/>
</Form.Group>
<Form.Button onClick={submitPayAddress}>更新支付设置</Form.Button>
<Form.Group inline>
<Form.Button onClick={submitPaymentConfig}>
更新支付设置
</Form.Button>
<Form.Checkbox
checked={inputs.PaymentEnabled === 'true'}
label='开启在线支付'
name='PaymentEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3' inverted={isDark}>
配置登录注册
</Header>
<Header as='h3'>配置登录注册</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.PasswordLoginEnabled === 'true'}
@@ -449,6 +508,12 @@ const SystemSetting = () => {
name='GitHubOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.LinuxDoOAuthEnabled === 'true'}
label='允许通过 LINUX DO 账户登录 & 注册'
name='LinuxDoOAuthEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.WeChatAuthEnabled === 'true'}
label='允许通过微信登录 & 注册'
@@ -475,9 +540,15 @@ const SystemSetting = () => {
name='TurnstileCheckEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.UserSelfDeletionEnabled === 'true'}
label='允许用户自行删除账户'
name='UserSelfDeletionEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Divider />
<Header as='h3' inverted={isDark}>
<Header as='h3'>
配置邮箱域名白名单
<Header.Subheader>
用以防止恶意用户利用临时邮箱批量注册
@@ -491,14 +562,6 @@ const SystemSetting = () => {
checked={inputs.EmailDomainRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={3}>
<Form.Checkbox
label='启用邮箱别名限制例如ab.cd@gmail.com'
name='EmailAliasRestrictionEnabled'
onChange={handleInputChange}
checked={inputs.EmailAliasRestrictionEnabled === 'true'}
/>
</Form.Group>
<Form.Group widths={2}>
<Form.Dropdown
label='允许的邮箱域名'
@@ -542,7 +605,7 @@ const SystemSetting = () => {
保存邮箱域名白名单设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
<Header as='h3'>
配置 SMTP
<Header.Subheader>用以支持系统的邮件发送</Header.Subheader>
</Header>
@@ -601,7 +664,7 @@ const SystemSetting = () => {
</Form.Group>
<Form.Button onClick={submitSMTP}>保存 SMTP 设置</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
<Header as='h3'>
配置 GitHub OAuth App
<Header.Subheader>
用以支持通过 GitHub 进行登录注册
@@ -643,7 +706,59 @@ const SystemSetting = () => {
保存 GitHub OAuth 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
<Header as='h3'>
配置 LINUX DO Oauth
<Header.Subheader>
用以支持通过 LINUX DO 进行登录注册
<a
href='https://connect.linux.do'
target='_blank'
rel='noreferrer'
>
点击此处
</a>
管理你的 LINUX DO OAuth
</Header.Subheader>
</Header>
<Message>
Homepage URL <code>{inputs.ServerAddress}</code>
Authorization callback URL {' '}
<code>{`${inputs.ServerAddress}/oauth/linuxdo`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='LINUX DO Client ID'
name='LinuxDoClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.LinuxDoClientId}
placeholder='输入你注册的 LINUX DO OAuth 的 ID'
/>
<Form.Input
label='LINUX DO Client Secret'
name='LinuxDoClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.LinuxDoClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
<Form.Input
label='限制最低信任等级'
name='LinuxDoMinLevel'
onChange={handleInputChange}
type='number'
min={0}
max={4}
value={inputs.LinuxDoMinLevel}
placeholder='输入允许使用的最低 LINUX DO 信任等级'
/>
</Form.Group>
<Form.Button onClick={submitLinuxDoOAuth}>
保存 LINUX DO OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 WeChat Server
<Header.Subheader>
用以支持通过微信进行登录注册
@@ -688,9 +803,7 @@ const SystemSetting = () => {
保存 WeChat Server 设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
配置 Telegram 登录
</Header>
<Header as='h3'>配置 Telegram 登录</Header>
<Form.Group inline>
<Form.Input
label='Telegram Bot Token'
@@ -711,7 +824,7 @@ const SystemSetting = () => {
保存 Telegram 登录设置
</Form.Button>
<Divider />
<Header as='h3' inverted={isDark}>
<Header as='h3'>
配置 Turnstile
<Header.Subheader>
用以支持用户校验

View File

@@ -210,21 +210,39 @@ const UsersTable = () => {
</Button>
</>
)}
<Popconfirm
title='确定是否要删除此用户?'
content='硬删除,此修改将不可逆'
okType={'danger'}
position={'left'}
onConfirm={() => {
manageUser(record.username, 'delete', record).then(() => {
removeRecord(record.id);
});
}}
>
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
删除
</Button>
</Popconfirm>
{record.DeletedAt !== null ? (
<Popconfirm
title='确定是否要删除此用户?'
content='硬删除,此修改将不可逆'
okType={'danger'}
position={'left'}
onConfirm={() => {
hardDeleteUser(record.id).then(() => {
removeRecord(record.id);
});
}}
>
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
永久删除
</Button>
</Popconfirm>
) : (
<Popconfirm
title='确定是否要删除此用户?'
content='软删除,数据依然留底'
okType={'danger'}
position={'left'}
onConfirm={() => {
manageUser(record.username, 'delete', record).then(() => {
record.DeletedAt = new Date();
});
}}
>
<Button theme='light' type='danger' style={{ marginRight: 1 }}>
删除
</Button>
</Popconfirm>
)}
</div>
),
},
@@ -235,8 +253,6 @@ const UsersTable = () => {
const [activePage, setActivePage] = useState(1);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [searchGroup, setSearchGroup] = useState('');
const [groupOptions, setGroupOptions] = useState([]);
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
const [showAddUser, setShowAddUser] = useState(false);
const [showEditUser, setShowEditUser] = useState(false);
@@ -300,7 +316,6 @@ const UsersTable = () => {
.catch((reason) => {
showError(reason);
});
fetchGroups().then();
}, []);
const manageUser = async (username, action, record) => {
@@ -321,6 +336,18 @@ const UsersTable = () => {
setUsers(newUsers);
} else {
showError(message);
throw new Error(message);
}
};
const hardDeleteUser = async (userId) => {
const res = await API.delete('/api/user/' + userId);
const { success, message } = res.data;
if (success) {
showSuccess('操作成功完成!');
} else {
showError(message);
throw new Error(message);
}
};
@@ -343,15 +370,15 @@ const UsersTable = () => {
}
};
const searchUsers = async (searchKeyword, searchGroup) => {
if (searchKeyword === '' && searchGroup === '') {
const searchUsers = async () => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadUsers(0);
setActivePage(1);
return;
}
setSearching(true);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}`);
const { success, message, data } = res.data;
if (success) {
setUsers(data);
@@ -412,25 +439,6 @@ const UsersTable = () => {
}
};
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
// add 'all' option
// res.data.data.unshift('all');
if (res === undefined) {
return;
}
setGroupOptions(
res.data.data.map((group) => ({
label: group,
value: group,
})),
);
} catch (error) {
showError(error.message);
}
};
return (
<>
<AddUser
@@ -444,44 +452,17 @@ const UsersTable = () => {
handleClose={closeEditUser}
editingUser={editingUser}
></EditUser>
<Form
onSubmit={() => {
searchUsers(searchKeyword, searchGroup);
}}
labelPosition='left'
>
<div style={{ display: 'flex' }}>
<Space>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID用户名显示名称以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Form.Select
field='group'
label='分组'
optionList={groupOptions}
onChange={(value) => {
setSearchGroup(value);
searchUsers(searchKeyword, value);
}}
/>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
style={{ marginRight: 8 }}
>
查询
</Button>
</Space>
</div>
<Form onSubmit={searchUsers}>
<Form.Input
label='搜索关键字'
icon='search'
field='keyword'
iconPosition='left'
placeholder='搜索用户的 ID用户名显示名称以及邮箱地址 ...'
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
</Form>
<Table

View File

@@ -14,36 +14,11 @@ export async function getOAuthState() {
export async function onGitHubOAuthClicked(github_client_id) {
const state = await getOAuthState();
if (!state) return;
window.open(
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`,
);
location.href = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`;
}
let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');
const { success, data } = res.data;
if (!success) {
return;
}
channelModels = data;
localStorage.setItem('channel_models', JSON.stringify(data));
}
export function getChannelModels(type) {
if (channelModels !== undefined && type in channelModels) {
if (!channelModels[type]) {
return [];
}
return channelModels[type];
}
let models = localStorage.getItem('channel_models');
if (!models) {
return [];
}
channelModels = JSON.parse(models);
if (type in channelModels) {
return channelModels[type];
}
return [];
export async function onLinuxDoOAuthClicked(linuxdo_client_id) {
const state = await getOAuthState();
if (!state) return;
location.href = `https://connect.linux.do/oauth2/authorize?client_id=${linuxdo_client_id}&response_type=code&state=${state}&scope=user:profile`;
}

View File

@@ -22,13 +22,6 @@ export const CHANNEL_OPTIONS = [
color: 'indigo',
label: 'Anthropic Claude',
},
{
key: 33,
text: 'AWS Claude',
value: 33,
color: 'indigo',
label: 'AWS Claude',
},
{
key: 3,
text: 'Azure OpenAI',
@@ -50,13 +43,6 @@ export const CHANNEL_OPTIONS = [
color: 'orange',
label: 'Google Gemini',
},
{
key: 34,
text: 'Cohere',
value: 34,
color: 'purple',
label: 'Cohere',
},
{
key: 15,
text: '百度文心千帆',
@@ -86,13 +72,13 @@ export const CHANNEL_OPTIONS = [
label: '智谱 ChatGLM',
},
{
key: 26,
key: 16,
text: '智谱 GLM-4V',
value: 26,
color: 'purple',
label: '智谱 GLM-4V',
},
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 31, text: '零一万物', value: 31, color: 'green', label: '零一万物' },

View File

@@ -1,36 +0,0 @@
import { createContext, useCallback, useContext, useState } from 'react';
const ThemeContext = createContext(null);
export const useTheme = () => useContext(ThemeContext);
const SetThemeContext = createContext(null);
export const useSetTheme = () => useContext(SetThemeContext);
export const ThemeProvider = ({ children }) => {
const [theme, _setTheme] = useState(() => {
try {
return localStorage.getItem('theme-mode') || null;
} catch {
return null;
}
});
const setTheme = useCallback((input) => {
_setTheme(input ? 'dark' : 'light');
const body = document.body;
if (!input) {
body.removeAttribute('theme-mode');
localStorage.setItem('theme-mode', 'light');
} else {
body.setAttribute('theme-mode', 'dark');
localStorage.setItem('theme-mode', 'dark');
}
}, []);
return (
<SetThemeContext.Provider value={setTheme}>
<ThemeContext.Provider value={theme}>{children}</ThemeContext.Provider>
</SetThemeContext.Provider>
);
};

View File

@@ -135,32 +135,6 @@ export function renderQuota(quota, digits = 2) {
return renderNumber(quota);
}
export function renderModelPrice(
modelRatio,
modelPrice = -1,
completionRatio,
groupRatio,
) {
// 1 ratio = $0.002 / 1K tokens
if (modelPrice !== -1) {
return '模型价格:$' + modelPrice * groupRatio;
} else {
if (completionRatio === undefined) {
completionRatio = 0;
}
let inputRatioPrice = modelRatio * 0.002 * groupRatio;
let completionRatioPrice =
modelRatio * completionRatio * 0.002 * groupRatio;
return (
'输入:$' +
inputRatioPrice.toFixed(3) +
'/1K tokens补全$' +
completionRatioPrice.toFixed(3) +
'/1K tokens'
);
}
}
export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
@@ -190,23 +164,24 @@ const colors = [
export const modelColorMap = {
'dall-e': 'rgb(147,112,219)', // 深紫色
// 'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调
midjourney: 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调
'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色
// 'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃色
'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色
'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
'gpt-4': 'rgb(135,206,235)', // 天蓝色
// 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝
'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝
'gpt-4-0125-preview': 'rgb(2,177,236)', // 深天蓝
'gpt-4-turbo-preview': 'rgb(2,177,255)', // 深天蓝
'gpt-4-32k': 'rgb(104,111,238)', // 中紫色
// 'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色
'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色
'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色
'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝
'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色
@@ -214,7 +189,7 @@ export const modelColorMap = {
'text-ada-001': 'rgb(255,192,203)', // 粉红色
'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色
'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色
// 'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色
'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色
'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色与Curie相同表示同一个系列
'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色
'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红
@@ -226,10 +201,6 @@ export const modelColorMap = {
'tts-1-hd': 'rgb(255,215,0)', // 金色
'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别)
'whisper-1': 'rgb(245,245,220)', // 米色
'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色
'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色
'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色
'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别)
};
export function stringToColor(str) {

Some files were not shown because too many files have changed in this diff Show More