mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
145 Commits
v0.2.7.7-a
...
v0.2.10-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27b8495698 | ||
|
|
333849429b | ||
|
|
50eab6b4e4 | ||
|
|
ed972eef06 | ||
|
|
c6ff785a83 | ||
|
|
2e734e0c37 | ||
|
|
af33f36c7b | ||
|
|
3aa86a8cd9 | ||
|
|
af7fecbfa7 | ||
|
|
3fbdd502b6 | ||
|
|
052bc2075b | ||
|
|
5f3798053f | ||
|
|
e31022c676 | ||
|
|
fff7609f06 | ||
|
|
9032b5cfbf | ||
|
|
131453dac8 | ||
|
|
ed948c121a | ||
|
|
a03cd15505 | ||
|
|
02f5137781 | ||
|
|
e6df0ed20c | ||
|
|
f505afdc10 | ||
|
|
feb1d76942 | ||
|
|
6263616cd9 | ||
|
|
9b14c5da63 | ||
|
|
6bbf1d4843 | ||
|
|
13c993d87e | ||
|
|
cb73889353 | ||
|
|
804aad3f37 | ||
|
|
3af62a3efa | ||
|
|
be54369c12 | ||
|
|
0cbf8e07e7 | ||
|
|
1675679be9 | ||
|
|
0b5f2a7089 | ||
|
|
b5bb708072 | ||
|
|
2650ec9b59 | ||
|
|
d168a685c1 | ||
|
|
a9e3555cac | ||
|
|
8d83630d85 | ||
|
|
a60f209c85 | ||
|
|
a0d20896b3 | ||
|
|
5cab06d1ce | ||
|
|
e3b3fdec48 | ||
|
|
5863aa8061 | ||
|
|
208bc5e794 | ||
|
|
0ada2371b6 | ||
|
|
8bc1e956cf | ||
|
|
a0673ef2b6 | ||
|
|
416f831a6c | ||
|
|
0b4317ce28 | ||
|
|
12e2481acb | ||
|
|
270709064d | ||
|
|
0830ef3305 | ||
|
|
722cc174b7 | ||
|
|
97c18d0c7f | ||
|
|
2223aeb022 | ||
|
|
4b1e83c42d | ||
|
|
ecf2f7f212 | ||
|
|
01fd8b53a6 | ||
|
|
e60f200192 | ||
|
|
033359e93c | ||
|
|
c41820541d | ||
|
|
228f0c5ee5 | ||
|
|
8a5e074f14 | ||
|
|
ac4262c542 | ||
|
|
400b2b0ed0 | ||
|
|
27b034674d | ||
|
|
1379d7f184 | ||
|
|
716bf6f48a | ||
|
|
2422eb2820 | ||
|
|
46e03683ce | ||
|
|
ff0985f06e | ||
|
|
a8ac8a25d5 | ||
|
|
5b2082ba58 | ||
|
|
967ccabb56 | ||
|
|
144513f1d8 | ||
|
|
e3087e9bea | ||
|
|
484a8595e4 | ||
|
|
c97e2875b4 | ||
|
|
64794630c8 | ||
|
|
fc5055c766 | ||
|
|
27eb358497 | ||
|
|
6810ee0a28 | ||
|
|
7c4d9d225e | ||
|
|
fefe5913e9 | ||
|
|
142d5fb209 | ||
|
|
f6ccd402e2 | ||
|
|
1c371300ab | ||
|
|
918690701d | ||
|
|
f0008e95fa | ||
|
|
7a249b206d | ||
|
|
0cc7f5cca6 | ||
|
|
ed86ec8b59 | ||
|
|
895ee09b33 | ||
|
|
61006bed9e | ||
|
|
58591b8c2a | ||
|
|
69b9950aa0 | ||
|
|
bcf1b160d1 | ||
|
|
1e0e40aa69 | ||
|
|
74392efbed | ||
|
|
cc020d6a40 | ||
|
|
daa1741aed | ||
|
|
a25bcaa58f | ||
|
|
d34b601dae | ||
|
|
9932962320 | ||
|
|
00d6cda9ed | ||
|
|
5ffb520363 | ||
|
|
e0f80cdb8f | ||
|
|
7a7a923504 | ||
|
|
4f6c171a08 | ||
|
|
310f8c247e | ||
|
|
811019bf5c | ||
|
|
0ed6600437 | ||
|
|
1fe7f14d57 | ||
|
|
a7bafec1bf | ||
|
|
ed951b3974 | ||
|
|
c74e43b8fd | ||
|
|
03bd9b0cc4 | ||
|
|
149902bd8a | ||
|
|
cd3ed22045 | ||
|
|
2329d387ca | ||
|
|
7d18a8e2a9 | ||
|
|
80af3718d0 | ||
|
|
77ea6bec46 | ||
|
|
c0ab8ae953 | ||
|
|
923c2dee32 | ||
|
|
ea17a46d8e | ||
|
|
bfe9e5d25a | ||
|
|
831ff47254 | ||
|
|
11eaba6b5d | ||
|
|
c2e4ec25c8 | ||
|
|
8537f10412 | ||
|
|
71d60eeef7 | ||
|
|
247ae0988f | ||
|
|
0907fa6994 | ||
|
|
9855343aa8 | ||
|
|
bd50fde268 | ||
|
|
4267de5642 | ||
|
|
8b55116563 | ||
|
|
f35e63e3f3 | ||
|
|
17c409de23 | ||
|
|
e4753e7411 | ||
|
|
9adefa80b9 | ||
|
|
4ce2381182 | ||
|
|
62afc21ea5 | ||
|
|
7ddb7c586d |
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
- name: 交流社区
|
||||
url: https://linux.do
|
||||
about: 项目交流社区
|
||||
|
||||
5
.github/workflows/docker-image-amd64.yml
vendored
5
.github/workflows/docker-image-amd64.yml
vendored
@@ -1,9 +1,6 @@
|
||||
name: Publish Docker image (amd64)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
@@ -42,7 +39,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
2
.github/workflows/docker-image-arm64.yml
vendored
2
.github/workflows/docker-image-arm64.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
26
README.md
26
README.md
@@ -1,6 +1,13 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
# New API
|
||||
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
</div>
|
||||
|
||||
> [!NOTE]
|
||||
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
|
||||
|
||||
@@ -44,7 +51,7 @@
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*, g-*)
|
||||
2. 智谱glm-4v,glm-4v识图
|
||||
3. Anthropic Claude 3
|
||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||
@@ -54,10 +61,12 @@
|
||||
8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
|
||||
9. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/),[对接文档](Rerank.md)
|
||||
10. Dify
|
||||
11. Vertex AI,目前兼容Claude,Gemini,Llama3.1
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`。
|
||||
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒。
|
||||
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`。
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true`,建议开启,不影响客户端传入stream_options参数返回结果。
|
||||
@@ -65,7 +74,7 @@
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
|
||||
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
|
||||
## 部署
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
@@ -114,22 +123,13 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
## Suno接口设置文档
|
||||
[对接文档](Suno.md)
|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
||||
|
||||
## 界面截图
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
夜间模式
|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
## 相关项目
|
||||
|
||||
@@ -9,9 +9,20 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Pay Settings
|
||||
|
||||
var StripeApiSecret = ""
|
||||
var StripeWebhookSecret = ""
|
||||
var StripePriceId = ""
|
||||
var PaymentEnabled = false
|
||||
var StripeUnitPrice = 8.0
|
||||
var MinTopUp = 5
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "New API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var OutProxyUrl = ""
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
@@ -41,10 +52,12 @@ var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var LinuxDoOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
var UserSelfDeletionEnabled = false
|
||||
|
||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||
@@ -75,6 +88,10 @@ var SMTPToken = ""
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDoClientId = ""
|
||||
var LinuxDoClientSecret = ""
|
||||
var LinuxDoMinLevel = 0
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
@@ -112,6 +129,9 @@ var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||
|
||||
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
@@ -176,6 +196,12 @@ const (
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
@@ -213,6 +239,8 @@ const (
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
@@ -259,4 +287,6 @@ var ChannelBaseURLs = []string{
|
||||
"", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"errors"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type outlookAuth struct {
|
||||
@@ -30,3 +31,10 @@ func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isOutlookServer(server string) bool {
|
||||
// 兼容多地区的outlook邮箱和ofb邮箱
|
||||
// 其实应该加一个Option来区分是否用LOGIN的方式登录
|
||||
// 先临时兼容一下
|
||||
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func generateMessageID() string {
|
||||
domain := strings.Split(SMTPFrom, "@")[1]
|
||||
domain := strings.Split(SMTPAccount, "@")[1]
|
||||
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
if SMTPFrom == "" { // for compatibility
|
||||
SMTPFrom = SMTPAccount
|
||||
}
|
||||
if SMTPServer == "" && SMTPAccount == "" {
|
||||
return fmt.Errorf("SMTP 服务器未配置")
|
||||
}
|
||||
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||
"From: %s<%s>\r\n"+
|
||||
@@ -68,7 +71,7 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if strings.HasSuffix(SMTPAccount, "outlook.com") {
|
||||
} else if isOutlookServer(SMTPAccount) {
|
||||
auth = LoginAuth(SMTPAccount, SMTPToken)
|
||||
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
|
||||
} else {
|
||||
|
||||
84
common/hash.go
Normal file
84
common/hash.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Sha256Raw(data string) []byte {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data string) string {
|
||||
return hex.EncodeToString(Sha1Raw([]byte(data)))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
|
||||
func RandomBytes(length int) []byte {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func RandomString(length int) string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomHex(length int) string {
|
||||
const chars = "abcdef0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomNumber(length int) string {
|
||||
const chars = "0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomUUID() string {
|
||||
all := RandomHex(32)
|
||||
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -31,9 +30,24 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := ProxiedHttpHead(url, OutProxyUrl)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoImageRequest(url)
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := ProxiedHttpGet(url, OutProxyUrl)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -52,9 +66,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := DoImageRequest(imageUrl)
|
||||
response, err := ProxiedHttpGet(imageUrl, OutProxyUrl)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
@@ -66,7 +80,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
|
||||
var readData []byte
|
||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
|
||||
// 从response.Body读取更多的数据直到达到当前的限制
|
||||
additionalData := make([]byte, limit-int64(len(readData)))
|
||||
@@ -92,11 +106,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
SysLog(err.Error())
|
||||
}
|
||||
format = "webp"
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
|
||||
}
|
||||
}
|
||||
|
||||
// LogJson 仅供测试使用 only for test
|
||||
func LogJson(ctx context.Context, msg string, obj any) {
|
||||
jsonStr, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
func LogQuotaF(quota float64) string {
|
||||
if DisplayInCurrencyEnabled {
|
||||
return fmt.Sprintf("$%.6f 额度", quota/QuotaPerUnit)
|
||||
} else {
|
||||
return fmt.Sprintf("%d 点额度", int64(quota))
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
@@ -23,44 +23,46 @@ const (
|
||||
|
||||
var defaultModelRatio = map[string]float64{
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4o-gizmo-*": 2.5,
|
||||
"gpt-4-all": 15,
|
||||
"gpt-4o-all": 15,
|
||||
"gpt-4": 15,
|
||||
//"gpt-4-0314": 15, //deprecated
|
||||
"gpt-4-0613": 15,
|
||||
"gpt-4-32k": 30,
|
||||
//"gpt-4-32k-0314": 30, //deprecated
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-mini": 0.075,
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens
|
||||
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"gpt-3.5-turbo-0125": 0.25,
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
//"text-davinci-002": 10,
|
||||
//"text-davinci-003": 10,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"g-*": 15,
|
||||
"gpt-4": 15,
|
||||
"gpt-4-0314": 15,
|
||||
"gpt-4-0613": 15,
|
||||
"gpt-4-32k": 30,
|
||||
"gpt-4-32k-0314": 30,
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 2.5, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
"o1-mini-2024-09-12": 1.5,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"gpt-3.5-turbo-0125": 0.25,
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
"text-davinci-002": 10,
|
||||
"text-davinci-003": 10,
|
||||
"text-davinci-edit-001": 10,
|
||||
"code-davinci-edit-001": 10,
|
||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||
@@ -82,9 +84,9 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5,
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
||||
"ERNIE-3.5-8K": 0.012 * RMB,
|
||||
"ERNIE-3.5-8K-0205": 0.024 * RMB,
|
||||
@@ -106,8 +108,10 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-1.0-pro-vision-001": 1,
|
||||
"gemini-1.0-pro-001": 1,
|
||||
"gemini-1.5-pro-latest": 1,
|
||||
"gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens
|
||||
"gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens
|
||||
"gemini-1.5-flash-latest": 1,
|
||||
"gemini-1.5-flash-exp-0827": 1,
|
||||
"gemini-1.0-pro-latest": 1,
|
||||
"gemini-1.0-pro-vision-latest": 1,
|
||||
"gemini-ultra": 1,
|
||||
@@ -119,6 +123,13 @@ var defaultModelRatio = map[string]float64{
|
||||
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
|
||||
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
|
||||
"glm-3-turbo": 0.3572,
|
||||
"glm-4-plus": 0.05 * RMB,
|
||||
"glm-4-0520": 0.1 * RMB,
|
||||
"glm-4-air": 0.001 * RMB,
|
||||
"glm-4-airx": 0.01 * RMB,
|
||||
"glm-4-long": 0.001 * RMB,
|
||||
"glm-4-flash": 0,
|
||||
"glm-4v-plus": 0.01 * RMB,
|
||||
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
@@ -137,26 +148,28 @@ var defaultModelRatio = map[string]float64{
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
// https://platform.lingyiwanwu.com/docs#-计费单元
|
||||
// 已经按照 7.2 来换算美元价格
|
||||
"yi-34b-chat-0205": 0.18,
|
||||
"yi-34b-chat-200k": 0.864,
|
||||
"yi-vl-plus": 0.432,
|
||||
"yi-large": 20.0 / 1000 * RMB,
|
||||
"yi-medium": 2.5 / 1000 * RMB,
|
||||
"yi-vision": 6.0 / 1000 * RMB,
|
||||
"yi-medium-200k": 12.0 / 1000 * RMB,
|
||||
"yi-spark": 1.0 / 1000 * RMB,
|
||||
"yi-large-rag": 25.0 / 1000 * RMB,
|
||||
"yi-large-turbo": 12.0 / 1000 * RMB,
|
||||
"yi-large-preview": 20.0 / 1000 * RMB,
|
||||
"yi-large-rag-preview": 25.0 / 1000 * RMB,
|
||||
"command": 0.5,
|
||||
"command-nightly": 0.5,
|
||||
"command-light": 0.5,
|
||||
"command-light-nightly": 0.5,
|
||||
"command-r": 0.25,
|
||||
"command-r-plus ": 1.5,
|
||||
"deepseek-chat": 0.07,
|
||||
"deepseek-coder": 0.07,
|
||||
"yi-34b-chat-0205": 0.18,
|
||||
"yi-34b-chat-200k": 0.864,
|
||||
"yi-vl-plus": 0.432,
|
||||
"yi-large": 20.0 / 1000 * RMB,
|
||||
"yi-medium": 2.5 / 1000 * RMB,
|
||||
"yi-vision": 6.0 / 1000 * RMB,
|
||||
"yi-medium-200k": 12.0 / 1000 * RMB,
|
||||
"yi-spark": 1.0 / 1000 * RMB,
|
||||
"yi-large-rag": 25.0 / 1000 * RMB,
|
||||
"yi-large-turbo": 12.0 / 1000 * RMB,
|
||||
"yi-large-preview": 20.0 / 1000 * RMB,
|
||||
"yi-large-rag-preview": 25.0 / 1000 * RMB,
|
||||
"command": 0.5,
|
||||
"command-nightly": 0.5,
|
||||
"command-light": 0.5,
|
||||
"command-light-nightly": 0.5,
|
||||
"command-r": 0.25,
|
||||
"command-r-plus": 1.5,
|
||||
"command-r-08-2024": 0.075,
|
||||
"command-r-plus-08-2024": 1.25,
|
||||
"deepseek-chat": 0.07,
|
||||
"deepseek-coder": 0.07,
|
||||
// Perplexity online 模型对搜索额外收费,有需要应自行调整,此处不计入搜索费用
|
||||
"llama-3-sonar-small-32k-chat": 0.2 / 1000 * USD,
|
||||
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
|
||||
@@ -167,8 +180,10 @@ var defaultModelRatio = map[string]float64{
|
||||
var defaultModelPrice = map[string]float64{
|
||||
"suno_music": 0.1,
|
||||
"suno_lyrics": 0.01,
|
||||
"dall-e-2": 0.02,
|
||||
"dall-e-3": 0.04,
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"g-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
@@ -198,9 +213,10 @@ var (
|
||||
|
||||
var CompletionRatio map[string]float64 = nil
|
||||
var defaultCompletionRatio = map[string]float64{
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"g-*": 2,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-4o-all": 2,
|
||||
}
|
||||
|
||||
func GetModelPriceMap() map[string]float64 {
|
||||
@@ -233,9 +249,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||
GetModelPriceMap()
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
price, ok := modelPriceMap[name]
|
||||
if !ok {
|
||||
@@ -276,6 +291,8 @@ func GetModelRatio(name string) float64 {
|
||||
GetModelRatioMap()
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
ratio, ok := modelRatioMap[name]
|
||||
if !ok {
|
||||
@@ -316,47 +333,52 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
|
||||
// https://openai.com/blog/new-embedding-models-and-api-updates
|
||||
// Updated GPT-3.5 Turbo model and lower pricing
|
||||
if strings.HasSuffix(name, "0125") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasSuffix(name, "1106") {
|
||||
return 2
|
||||
}
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
|
||||
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
|
||||
if name == "gpt-3.5-turbo" {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o") {
|
||||
if strings.HasPrefix(name, "gpt-4o-mini") || name == "gpt-4o-2024-08-06" {
|
||||
return 4
|
||||
}
|
||||
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
|
||||
if strings.HasPrefix(name, "gpt-4o-mini") || "gpt-4o-2024-08-06" == name {
|
||||
return 4
|
||||
}
|
||||
|
||||
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || strings.HasPrefix(name, "gpt-4o") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if "o1" == name || strings.HasPrefix(name, "o1-") {
|
||||
return 4
|
||||
}
|
||||
if name == "chatgpt-4o-latest" {
|
||||
return 3
|
||||
}
|
||||
if strings.Contains(name, "claude-instant-1") {
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
return 3
|
||||
} else if strings.Contains(name, "claude-2") {
|
||||
} else if strings.HasPrefix(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.Contains(name, "claude-3") {
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
if strings.HasPrefix(name, "mistral-") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-") {
|
||||
if strings.Contains(name, "flash") {
|
||||
return 4
|
||||
}
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "command") {
|
||||
@@ -365,6 +387,10 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 3
|
||||
case "command-r-plus":
|
||||
return 5
|
||||
case "command-r-08-2024":
|
||||
return 4
|
||||
case "command-r-plus-08-2024":
|
||||
return 4
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -31,14 +31,6 @@ func MapToJsonStr(m map[string]interface{}) string {
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func MapToJsonStrFloat(m map[string]float64) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
@@ -48,6 +40,11 @@ func StrToMap(str string) map[string]interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
func IsJsonStr(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
|
||||
23
common/user_groups.go
Normal file
23
common/user_groups.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
var UserUsableGroups = map[string]string{
|
||||
"default": "默认分组",
|
||||
"vip": "vip分组",
|
||||
}
|
||||
|
||||
func UserUsableGroups2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(UserUsableGroups)
|
||||
if err != nil {
|
||||
SysError("error marshalling user groups: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
|
||||
UserUsableGroups = make(map[string]string)
|
||||
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
|
||||
}
|
||||
@@ -1,12 +1,17 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/proxy"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -128,6 +133,11 @@ func IntMax(a int, b int) int {
|
||||
}
|
||||
}
|
||||
|
||||
func IsIP(s string) bool {
|
||||
ip := net.ParseIP(s)
|
||||
return ip != nil
|
||||
}
|
||||
|
||||
func GetUUID() string {
|
||||
code := uuid.New().String()
|
||||
code = strings.Replace(code, "-", "", -1)
|
||||
@@ -187,3 +197,56 @@ func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
|
||||
if "" == proxyUrl {
|
||||
return &http.Client{}, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(proxyUrl, "http") {
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(u),
|
||||
},
|
||||
}, nil
|
||||
} else if strings.HasPrefix(proxyUrl, "socks") {
|
||||
dialer, err := proxy.FromURL(u, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported proxy type")
|
||||
}
|
||||
|
||||
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
|
||||
client, err := GetProxiedHttpClient(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
|
||||
client, err := GetProxiedHttpClient(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Head(url)
|
||||
}
|
||||
|
||||
@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
|
||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
|
||||
var GeminiModelMap = map[string]string{
|
||||
"gemini-1.5-pro-latest": "v1beta",
|
||||
"gemini-1.5-pro-001": "v1beta",
|
||||
"gemini-1.5-pro": "v1beta",
|
||||
"gemini-1.5-pro-exp-0801": "v1beta",
|
||||
"gemini-1.5-flash-latest": "v1beta",
|
||||
"gemini-1.5-flash-001": "v1beta",
|
||||
"gemini-1.5-flash": "v1beta",
|
||||
"gemini-ultra": "v1beta",
|
||||
"gemini-1.5-pro-latest": "v1beta",
|
||||
"gemini-1.5-pro-001": "v1beta",
|
||||
"gemini-1.5-pro": "v1beta",
|
||||
"gemini-1.5-pro-exp-0801": "v1beta",
|
||||
"gemini-1.5-pro-exp-0827": "v1beta",
|
||||
"gemini-1.5-flash-latest": "v1beta",
|
||||
"gemini-1.5-flash-exp-0827": "v1beta",
|
||||
"gemini-1.5-flash-001": "v1beta",
|
||||
"gemini-1.5-flash": "v1beta",
|
||||
"gemini-ultra": "v1beta",
|
||||
}
|
||||
|
||||
func InitEnv() {
|
||||
@@ -44,3 +46,6 @@ func InitEnv() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 是否生成初始令牌,默认关闭。
|
||||
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package constant
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
var WorkerValidKey = ""
|
||||
|
||||
func EnableWorker() bool {
|
||||
return WorkerUrl != ""
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
request := buildTestRequest(testModel)
|
||||
meta.UpstreamModelName = testModel
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
|
||||
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *dto.GeneralOpenAIRequest {
|
||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
Model: "", // this will be set later
|
||||
Stream: false,
|
||||
}
|
||||
if "o1" == model || strings.HasPrefix(model, "o1-") {
|
||||
testRequest.MaxCompletionTokens = 1
|
||||
} else {
|
||||
testRequest.MaxTokens = 1
|
||||
}
|
||||
content, _ := json.Marshal("hi")
|
||||
testMessage := dto.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
testRequest.Model = model
|
||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||
return testRequest
|
||||
}
|
||||
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
ban := false
|
||||
if milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
ban = true
|
||||
}
|
||||
shouldBanChannel := false
|
||||
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
}
|
||||
|
||||
// parse *int to bool
|
||||
if !channel.GetAutoBan() {
|
||||
ban = false
|
||||
if milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
shouldBanChannel = true
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if ban && isChannelEnabled {
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
|
||||
|
||||
@@ -198,6 +198,28 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
keys = []string{channel.Key}
|
||||
}
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
@@ -297,6 +319,27 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -123,6 +123,8 @@ func GitHubOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
|
||||
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
@@ -133,7 +135,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
@@ -17,3 +17,18 @@ func GetGroups(c *gin.Context) {
|
||||
"data": groupNames,
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserGroups(c *gin.Context) {
|
||||
usableGroups := make(map[string]string)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
if _, ok := common.UserUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = common.UserUsableGroups[groupName]
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": usableGroups,
|
||||
})
|
||||
}
|
||||
|
||||
239
controller/linuxdo.go
Normal file
239
controller/linuxdo.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LinuxDoOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type LinuxDoUser struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LinuxDoOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var linuxdoUser LinuxDoUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if linuxdoUser.ID == 0 {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
|
||||
return nil, errors.New("用户 LINUX DO 信任等级不足!")
|
||||
}
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxDoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
err := user.FillUserByLinuxDoId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
affCode := c.Query("aff")
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
|
||||
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if linuxdoUser.Name != "" {
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
} else {
|
||||
user.DisplayName = linuxdoUser.Username
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 LINUX DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(targetTimestamp)
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -233,7 +233,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDoClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
@@ -45,9 +47,9 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": constant.ServerAddress,
|
||||
"price": constant.Price,
|
||||
"min_topup": constant.MinTopUp,
|
||||
"server_address": common.ServerAddress,
|
||||
"stripe_unit_price": common.StripeUnitPrice,
|
||||
"min_topup": common.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -61,7 +63,7 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
||||
"payment_enabled": common.PaymentEnabled,
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
},
|
||||
})
|
||||
@@ -204,7 +206,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -137,31 +137,63 @@ func init() {
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
models := model.GetGroupModels(user.Group)
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
permission := getPermission()
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: allowModel,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
|
||||
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDoOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDoClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 LINUX DO OAuth,请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -7,18 +7,11 @@ import (
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
// if no login, get default group ratio
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
group, err := model.CacheGetUserGroup(userId)
|
||||
if err == nil {
|
||||
groupRatio = common.GetGroupRatio(group)
|
||||
}
|
||||
pricing := model.GetPricing(group)
|
||||
pricing := model.GetPricing()
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"group_ratio": common.GroupRatio,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -121,6 +121,9 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
return false
|
||||
}
|
||||
@@ -151,9 +154,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
// azure处理超时不重试
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
99
controller/stripe.go
Normal file
99
controller/stripe.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v79"
|
||||
"github.com/stripe/stripe-go/v79/webhook"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := common.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if "" == referenceId {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
@@ -134,6 +134,8 @@ func AddToken(c *gin.Context) {
|
||||
UnlimitedQuota: token.UnlimitedQuota,
|
||||
ModelLimitsEnabled: token.ModelLimitsEnabled,
|
||||
ModelLimits: token.ModelLimits,
|
||||
AllowIps: token.AllowIps,
|
||||
Group: token.Group,
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
@@ -221,6 +223,8 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
|
||||
cleanToken.ModelLimits = token.ModelLimits
|
||||
cleanToken.AllowIps = token.AllowIps
|
||||
cleanToken.Group = token.Group
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
package controller
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stripe/stripe-go/v79"
|
||||
"github.com/stripe/stripe-go/v79/checkout/session"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
type PayRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
@@ -27,201 +25,114 @@ type AmountRequest struct {
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
|
||||
return nil
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: constant.EpayId,
|
||||
Key: constant.EpayKey,
|
||||
}, constant.PayAddress)
|
||||
|
||||
stripe.Key = common.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(common.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(common.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(common.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return nil
|
||||
return "", err
|
||||
}
|
||||
return withUrl
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func getPayMoney(amount float64, group string) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * constant.Price * topupGroupRatio
|
||||
return payMoney
|
||||
func GetPayAmount(count float64) float64 {
|
||||
return count * common.StripeUnitPrice
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := constant.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
return minTopup
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
func RequestPayLink(c *gin.Context) {
|
||||
var req PayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.PaymentMethod != "stripe" {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
user, _ := model.GetUserById(id, false)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
var payType epay.PurchaseType
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = epay.Alipay
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = epay.WechatPay
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
Device: epay.PC,
|
||||
NotifyUrl: notifyUrl,
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
|
||||
referenceId := "ref_" + common.Sha1(reference)
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / int(common.QuotaPerUnit)
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
return r
|
||||
}, map[string]string{})
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
log.Println("易支付回调失败 未找到配置信息")
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payLink": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestAmount(c *gin.Context) {
|
||||
@@ -231,21 +142,23 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetPayAmount(float64(req.Amount))
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
@@ -66,6 +67,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -186,6 +188,39 @@ func Register(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取插入后的用户ID
|
||||
var insertedUser model.User
|
||||
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户注册失败或用户ID获取失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 生成默认令牌
|
||||
if constant.GenerateDefaultToken {
|
||||
// 生成默认令牌
|
||||
token := model.Token{
|
||||
UserId: insertedUser.Id, // 使用插入后的用户ID
|
||||
Name: cleanUser.Username + "的初始令牌",
|
||||
Key: common.GenerateKey(),
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
AccessedTime: common.GetTimestamp(),
|
||||
ExpiredTime: -1, // 永不过期
|
||||
RemainQuota: 500000, // 示例额度
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "创建默认令牌失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -517,7 +552,7 @@ func UpdateSelf(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
func HardDeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -526,7 +561,7 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
originUser, err := model.GetUserByIdUnscoped(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -550,9 +585,23 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
if !common.UserSelfDeletionEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "当前设置不允许用户自我删除账号",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
|
||||
@@ -2,18 +2,17 @@ version: '3.4'
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
# build: .
|
||||
image: pengzhile/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./data/new-api:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
@@ -23,13 +22,22 @@ services:
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
db:
|
||||
image: mysql:8.2.0
|
||||
container_name: mysql
|
||||
restart: always
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||
environment:
|
||||
TZ: Asia/Shanghai # 设置时区
|
||||
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||
MYSQL_USER: newapi # 创建专用用户
|
||||
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||
MYSQL_DATABASE: new-api # 自动创建数据库
|
||||
@@ -1,14 +1,17 @@
|
||||
package dto
|
||||
|
||||
type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
ReturnDocuments bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type RerankResponseDocument struct {
|
||||
Document any `json:"document"`
|
||||
Document any `json:"document,omitempty"`
|
||||
Index int `json:"index"`
|
||||
RelevanceScore float64 `json:"relevance_score"`
|
||||
}
|
||||
|
||||
@@ -2,36 +2,39 @@ package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
BestOf int `json:"best_of,omitempty"`
|
||||
Echo bool `json:"echo,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogitBias any `json:"logit_bias,omitempty"`
|
||||
LogProbs any `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_Tool_Calls,omitempty"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
|
||||
@@ -34,6 +34,7 @@ type OpenAITextResponseChoice struct {
|
||||
|
||||
type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
@@ -41,9 +42,9 @@ type OpenAITextResponse struct {
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
|
||||
25
go.mod
25
go.mod
@@ -1,14 +1,16 @@
|
||||
module one-api
|
||||
|
||||
// +heroku goVersion go1.18
|
||||
go 1.18
|
||||
go 1.21
|
||||
|
||||
toolchain go1.22.4
|
||||
|
||||
require (
|
||||
github.com/Calcium-Ion/go-epay v0.0.2
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -24,8 +26,10 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
golang.org/x/crypto v0.21.0
|
||||
github.com/stripe/stripe-go/v79 v79.12.0
|
||||
golang.org/x/crypto v0.26.0
|
||||
golang.org/x/image v0.15.0
|
||||
golang.org/x/net v0.28.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -38,9 +42,8 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||
@@ -51,6 +54,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.6.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
@@ -65,10 +69,10 @@ require (
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/stretchr/testify v1.9.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
@@ -76,10 +80,9 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/sys v0.24.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
48
go.sum
48
go.sum
@@ -1,5 +1,3 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
|
||||
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
@@ -23,8 +21,8 @@ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaU
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
@@ -37,6 +35,7 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g=
|
||||
@@ -57,6 +56,7 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
|
||||
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -81,7 +81,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -133,8 +134,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -142,8 +141,11 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
@@ -172,7 +174,10 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs=
|
||||
github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
@@ -191,21 +196,23 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -214,26 +221,27 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
|
||||
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
|
||||
@@ -16,6 +16,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
linuxDoEnable := session.Get("linuxdo_enable")
|
||||
useAccessToken := false
|
||||
if username == nil {
|
||||
// Check access token
|
||||
@@ -35,6 +36,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
|
||||
useAccessToken = true
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -83,6 +85,14 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户 LINUX DO 信任等级不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if role.(int) < minRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -162,6 +172,15 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !linuxDoEnabled {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
@@ -175,6 +194,8 @@ func TokenAuth() func(c *gin.Context) {
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
|
||||
@@ -22,6 +22,14 @@ type ModelRequest struct {
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := c.GetStringMap("allow_ips")
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||
return
|
||||
}
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
@@ -31,6 +39,20 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := common.UserUsableGroups[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if _, ok := common.GroupRatio[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set("group", userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
@@ -199,6 +221,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeVertexAi:
|
||||
c.Set("region", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
|
||||
@@ -36,6 +36,12 @@ func GetEnabledModels() []string {
|
||||
return models
|
||||
}
|
||||
|
||||
func GetAllEnableAbilities() []Ability {
|
||||
var abilities []Ability
|
||||
DB.Find(&abilities, "enabled = ?", true)
|
||||
return abilities
|
||||
}
|
||||
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
|
||||
@@ -205,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsLinuxDoEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if linuxDoEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set linuxdo enabled error: " + err.Error())
|
||||
}
|
||||
return linuxDoEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
@@ -269,9 +293,8 @@ func SyncChannelCache(frequency int) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
||||
model = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(model, "g-") {
|
||||
model = "g-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
|
||||
24
model/log.go
24
model/log.go
@@ -247,7 +247,25 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
return token
|
||||
}
|
||||
|
||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
|
||||
var total int64 = 0
|
||||
|
||||
for {
|
||||
if nil != ctx.Err() {
|
||||
return total, ctx.Err()
|
||||
}
|
||||
|
||||
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
|
||||
if nil != result.Error {
|
||||
return total, result.Error
|
||||
}
|
||||
|
||||
total += result.RowsAffected
|
||||
|
||||
if result.RowsAffected < int64(limit) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
@@ -31,10 +31,12 @@ func InitOptionMap() {
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
|
||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
@@ -60,17 +62,19 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
||||
common.OptionMap["OutProxyUrl"] = ""
|
||||
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = common.StripePriceId
|
||||
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoClientId"] = ""
|
||||
common.OptionMap["LinuxDoClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
common.OptionMap["TelegramBotName"] = ""
|
||||
common.OptionMap["WeChatServerAddress"] = ""
|
||||
@@ -86,6 +90,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
@@ -173,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
common.GitHubOAuthEnabled = boolValue
|
||||
case "LinuxDoOAuthEnabled":
|
||||
common.LinuxDoOAuthEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
common.WeChatAuthEnabled = boolValue
|
||||
case "TelegramOAuthEnabled":
|
||||
@@ -181,6 +188,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "UserSelfDeletionEnabled":
|
||||
common.UserSelfDeletionEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "EmailAliasRestrictionEnabled":
|
||||
@@ -240,29 +249,33 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
constant.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
constant.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
constant.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
constant.PayAddress = value
|
||||
case "CustomCallbackAddress":
|
||||
constant.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
constant.EpayId = value
|
||||
case "EpayKey":
|
||||
constant.EpayKey = value
|
||||
case "Price":
|
||||
constant.Price, _ = strconv.ParseFloat(value, 64)
|
||||
common.ServerAddress = value
|
||||
case "OutProxyUrl":
|
||||
common.OutProxyUrl = value
|
||||
case "StripeApiSecret":
|
||||
common.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
common.StripeWebhookSecret = value
|
||||
case "StripePriceId":
|
||||
common.StripePriceId = value
|
||||
case "PaymentEnabled":
|
||||
common.PaymentEnabled, _ = strconv.ParseBool(value)
|
||||
case "StripeUnitPrice":
|
||||
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
constant.MinTopUp, _ = strconv.Atoi(value)
|
||||
common.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
common.GitHubClientId = value
|
||||
case "GitHubClientSecret":
|
||||
common.GitHubClientSecret = value
|
||||
case "LinuxDoClientId":
|
||||
common.LinuxDoClientId = value
|
||||
case "LinuxDoClientSecret":
|
||||
common.LinuxDoClientSecret = value
|
||||
case "LinuxDoMinLevel":
|
||||
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
|
||||
case "Footer":
|
||||
common.Footer = value
|
||||
case "SystemName":
|
||||
@@ -303,6 +316,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = common.UpdateGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = common.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = common.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
|
||||
@@ -7,14 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type Pricing struct {
|
||||
Available bool `json:"available"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_group,omitempty"`
|
||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -23,40 +22,47 @@ var (
|
||||
updatePricingLock sync.Mutex
|
||||
)
|
||||
|
||||
func GetPricing(group string) []Pricing {
|
||||
func GetPricing() []Pricing {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
updatePricing()
|
||||
}
|
||||
if group != "" {
|
||||
userPricingMap := make([]Pricing, 0)
|
||||
models := GetGroupModels(group)
|
||||
for _, pricing := range pricingMap {
|
||||
if !common.StringsContains(models, pricing.ModelName) {
|
||||
pricing.Available = false
|
||||
}
|
||||
userPricingMap = append(userPricingMap, pricing)
|
||||
}
|
||||
return userPricingMap
|
||||
}
|
||||
//if group != "" {
|
||||
// userPricingMap := make([]Pricing, 0)
|
||||
// models := GetGroupModels(group)
|
||||
// for _, pricing := range pricingMap {
|
||||
// if !common.StringsContains(models, pricing.ModelName) {
|
||||
// pricing.Available = false
|
||||
// }
|
||||
// userPricingMap = append(userPricingMap, pricing)
|
||||
// }
|
||||
// return userPricingMap
|
||||
//}
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enabledModels := GetEnabledModels()
|
||||
allModels := make(map[string]int)
|
||||
for i, model := range enabledModels {
|
||||
allModels[model] = i
|
||||
enableAbilities := GetAllEnableAbilities()
|
||||
modelGroupsMap := make(map[string][]string)
|
||||
for _, ability := range enableAbilities {
|
||||
groups := modelGroupsMap[ability.Model]
|
||||
if groups == nil {
|
||||
groups = make([]string, 0)
|
||||
}
|
||||
if !common.StringsContains(groups, ability.Group) {
|
||||
groups = append(groups, ability.Group)
|
||||
}
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, _ := range allModels {
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
Available: true,
|
||||
ModelName: model,
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := common.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -23,10 +22,34 @@ type Token struct {
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
AllowIps *string `json:"allow_ips" gorm:"default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Group string `json:"group" gorm:"default:''"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (token *Token) GetIpLimitsMap() map[string]any {
|
||||
// delete empty spaces
|
||||
//split with \n
|
||||
ipLimitsMap := make(map[string]any)
|
||||
if token.AllowIps == nil {
|
||||
return ipLimitsMap
|
||||
}
|
||||
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
|
||||
if cleanIps == "" {
|
||||
return ipLimitsMap
|
||||
}
|
||||
ips := strings.Split(cleanIps, "\n")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
ip = strings.ReplaceAll(ip, ",", "")
|
||||
if common.IsIP(ip) {
|
||||
ipLimitsMap[ip] = true
|
||||
}
|
||||
}
|
||||
return ipLimitsMap
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
var tokens []*Token
|
||||
var err error
|
||||
@@ -130,7 +153,8 @@ func (token *Token) Insert() error {
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
var err error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||||
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -294,7 +318,7 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
|
||||
prompt = "您的额度已用尽"
|
||||
}
|
||||
if email != "" {
|
||||
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
|
||||
err = common.SendEmail(prompt, email,
|
||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||
if err != nil {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
}
|
||||
return topUp
|
||||
}
|
||||
|
||||
func Recharge(referenceId string, customerId string) (err error) {
|
||||
if referenceId == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quota float64
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err = tx.Save(topUp).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
quota = topUp.Money * common.QuotaPerUnit
|
||||
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.New("充值失败," + err.Error())
|
||||
}
|
||||
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.LogQuotaF(quota), topUp.Amount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
|
||||
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -35,6 +37,7 @@ type User struct {
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
@@ -64,7 +67,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
|
||||
func GetMaxUserId() int {
|
||||
var user User
|
||||
DB.Last(&user)
|
||||
DB.Unscoped().Last(&user)
|
||||
return user.Id
|
||||
}
|
||||
|
||||
@@ -119,6 +122,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
user := User{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.Unscoped().First(&user, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserIdByAffCode(affCode string) (int, error) {
|
||||
if affCode == "" {
|
||||
return 0, errors.New("affCode 为空!")
|
||||
@@ -331,6 +348,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByLinuxDoId() error {
|
||||
if user.LinuxDoId == "" {
|
||||
return errors.New("LINUX DO id 为空!")
|
||||
}
|
||||
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -370,6 +395,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
|
||||
return DB.Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -415,6 +444,18 @@ func IsUserEnabled(userId int) (bool, error) {
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func IsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
|
||||
@@ -36,8 +36,8 @@ type AliEmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
Embedding any `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
|
||||
@@ -105,7 +105,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
_, b64, err := common.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
common.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
|
||||
@@ -50,9 +50,9 @@ type BaiduEmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
Embedding any `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
|
||||
@@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -12,6 +11,8 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
}
|
||||
formatMessages := make([]dto.Message, 0)
|
||||
var lastMessage *dto.Message
|
||||
lastMessage := dto.Message{
|
||||
Role: "tool",
|
||||
}
|
||||
for i, message := range textRequest.Messages {
|
||||
//if message.Role == "system" {
|
||||
// if i != 0 {
|
||||
// message.Role = "user"
|
||||
// }
|
||||
//}
|
||||
if message.Role == "" {
|
||||
textRequest.Messages[i].Role = "user"
|
||||
}
|
||||
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
}
|
||||
if lastMessage != nil && lastMessage.Role == message.Role {
|
||||
if message.Role == "tool" {
|
||||
fmtMessage.ToolCallId = message.ToolCallId
|
||||
}
|
||||
if message.Role == "assistant" && message.ToolCalls != nil {
|
||||
fmtMessage.ToolCalls = message.ToolCalls
|
||||
}
|
||||
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
|
||||
if lastMessage.IsStringContent() && message.IsStringContent() {
|
||||
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||||
fmtMessage.Content = content
|
||||
@@ -135,10 +139,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
fmtMessage.Content = content
|
||||
}
|
||||
formatMessages = append(formatMessages, fmtMessage)
|
||||
lastMessage = &textRequest.Messages[i]
|
||||
lastMessage = fmtMessage
|
||||
}
|
||||
|
||||
claudeMessages := make([]ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
if message.IsStringContent() {
|
||||
@@ -154,10 +159,54 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
claudeRequest.System = content
|
||||
}
|
||||
} else {
|
||||
if isFirstMessage {
|
||||
isFirstMessage = false
|
||||
if message.Role != "user" {
|
||||
// fix: first message is assistant, add user message
|
||||
claudeMessage := ClaudeMessage{
|
||||
Role: "user",
|
||||
Content: []ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "...",
|
||||
},
|
||||
},
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
claudeMessage := ClaudeMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
if message.Role == "tool" {
|
||||
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
|
||||
lastMessage := claudeMessages[len(claudeMessages)-1]
|
||||
if content, ok := lastMessage.Content.(string); ok {
|
||||
lastMessage.Content = []ClaudeMediaMessage{
|
||||
{
|
||||
Type: "text",
|
||||
Text: content,
|
||||
},
|
||||
}
|
||||
}
|
||||
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
claudeMessages[len(claudeMessages)-1] = lastMessage
|
||||
continue
|
||||
} else {
|
||||
claudeMessage.Role = "user"
|
||||
claudeMessage.Content = []ClaudeMediaMessage{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolUseId: message.ToolCallId,
|
||||
Content: message.StringContent(),
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if message.IsStringContent() && message.ToolCalls == nil {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
|
||||
@@ -176,11 +225,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
|
||||
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -190,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
if message.ToolCalls != nil {
|
||||
for _, tc := range message.ToolCalls.([]interface{}) {
|
||||
toolCallJSON, _ := json.Marshal(tc)
|
||||
var toolCall dto.ToolCall
|
||||
err := json.Unmarshal(toolCallJSON, &toolCall)
|
||||
if err != nil {
|
||||
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
|
||||
continue
|
||||
}
|
||||
inputObj := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
|
||||
Type: "tool_use",
|
||||
Id: toolCall.ID,
|
||||
Name: toolCall.Function.Name,
|
||||
Input: inputObj,
|
||||
})
|
||||
}
|
||||
}
|
||||
claudeMessage.Content = claudeMediaMessages
|
||||
}
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
@@ -324,12 +395,13 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
if len(tools) > 0 {
|
||||
choice.Message.ToolCalls = tools
|
||||
}
|
||||
fullTextResponse.Model = claudeResponse.Model
|
||||
choices = append(choices, choice)
|
||||
fullTextResponse.Choices = choices
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage *dto.Usage
|
||||
usage = &dto.Usage{}
|
||||
@@ -411,7 +483,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -437,15 +509,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = promptTokens + completionTokens
|
||||
usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
} else {
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package cohere
|
||||
|
||||
var ModelList = []string{
|
||||
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
||||
"command-r", "command-r-plus",
|
||||
"command-r-08-2024", "command-r-plus-08-2024",
|
||||
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
|
||||
"command-light", "command-light-nightly", "command", "command-nightly",
|
||||
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ type CohereRequest struct {
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
SafetyMode string `json:"safety_mode,omitempty"`
|
||||
}
|
||||
|
||||
type ChatHistory struct {
|
||||
|
||||
@@ -23,6 +23,9 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
Stream: textRequest.Stream,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
}
|
||||
if common.CohereSafetySetting != "NONE" {
|
||||
cohereReq.SafetyMode = common.CohereSafetySetting
|
||||
}
|
||||
if cohereReq.MaxTokens == 0 {
|
||||
cohereReq.MaxTokens = 4000
|
||||
}
|
||||
@@ -44,6 +47,7 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &cohereReq
|
||||
}
|
||||
|
||||
|
||||
@@ -70,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = geminiChatStreamHandler(c, resp, info)
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = GeminiChatHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ const (
|
||||
|
||||
var ModelList = []string{
|
||||
"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
|
||||
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
|
||||
"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -86,7 +86,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
@@ -94,7 +94,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
@@ -279,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -32,7 +32,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
@@ -58,6 +58,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = jinaRerankHandler(c, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
err, usage = jinaEmbeddingHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -33,3 +33,28 @@ func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
||||
|
||||
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var jinaResp dto.OpenAIEmbeddingResponse
|
||||
err = json.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &jinaResp.Usage
|
||||
}
|
||||
|
||||
@@ -17,11 +17,25 @@ type OllamaRequest struct {
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Input []string `json:"input"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
Embedding any `json:"embedding,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
@@ -45,8 +44,15 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Prompt: strings.Join(request.ParseInput(), " "),
|
||||
Model: request.Model,
|
||||
Input: request.ParseInput(),
|
||||
Options: &Options{
|
||||
Seed: int(request.Seed),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +70,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
|
||||
}
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: ollamaEmbeddingResponse.Embedding,
|
||||
|
||||
@@ -78,6 +78,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if info.ChannelType != common.ChannelTypeOpenAI {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if "o1" == request.Model || strings.HasPrefix(request.Model, "o1-") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
package openai
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0613",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-vision-preview",
|
||||
"chatgpt-4o-latest",
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-3",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
83
relay/channel/siliconflow/adaptor.go
Normal file
83
relay/channel/siliconflow/adaptor.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package siliconflow
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = siliconflowRerankHandler(c, resp)
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
51
relay/channel/siliconflow/constant.go
Normal file
51
relay/channel/siliconflow/constant.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package siliconflow
|
||||
|
||||
var ModelList = []string{
|
||||
"THUDM/glm-4-9b-chat",
|
||||
//"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
//"TencentARC/PhotoMaker",
|
||||
"InstantX/InstantID",
|
||||
//"stabilityai/stable-diffusion-2-1",
|
||||
//"stabilityai/sd-turbo",
|
||||
//"stabilityai/sdxl-turbo",
|
||||
"ByteDance/SDXL-Lightning",
|
||||
"deepseek-ai/deepseek-llm-67b-chat",
|
||||
"Qwen/Qwen1.5-14B-Chat",
|
||||
"Qwen/Qwen1.5-7B-Chat",
|
||||
"Qwen/Qwen1.5-110B-Chat",
|
||||
"Qwen/Qwen1.5-32B-Chat",
|
||||
"01-ai/Yi-1.5-6B-Chat",
|
||||
"01-ai/Yi-1.5-9B-Chat-16K",
|
||||
"01-ai/Yi-1.5-34B-Chat-16K",
|
||||
"THUDM/chatglm3-6b",
|
||||
"deepseek-ai/DeepSeek-V2-Chat",
|
||||
"Qwen/Qwen2-72B-Instruct",
|
||||
"Qwen/Qwen2-7B-Instruct",
|
||||
"Qwen/Qwen2-57B-A14B-Instruct",
|
||||
//"stabilityai/stable-diffusion-3-medium",
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
|
||||
"Qwen/Qwen2-1.5B-Instruct",
|
||||
"internlm/internlm2_5-7b-chat",
|
||||
"BAAI/bge-large-en-v1.5",
|
||||
"BAAI/bge-large-zh-v1.5",
|
||||
"Pro/Qwen/Qwen2-7B-Instruct",
|
||||
"Pro/Qwen/Qwen2-1.5B-Instruct",
|
||||
"Pro/Qwen/Qwen1.5-7B-Chat",
|
||||
"Pro/THUDM/glm-4-9b-chat",
|
||||
"Pro/THUDM/chatglm3-6b",
|
||||
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
|
||||
"Pro/01-ai/Yi-1.5-6B-Chat",
|
||||
"Pro/google/gemma-2-9b-it",
|
||||
"Pro/internlm/internlm2_5-7b-chat",
|
||||
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
"iic/SenseVoiceSmall",
|
||||
"netease-youdao/bce-embedding-base_v1",
|
||||
"BAAI/bge-m3",
|
||||
"internlm/internlm2_5-20b-chat",
|
||||
"Qwen/Qwen2-Math-72B-Instruct",
|
||||
"netease-youdao/bce-reranker-base_v1",
|
||||
"BAAI/bge-reranker-v2-m3",
|
||||
}
|
||||
var ChannelName = "siliconflow"
|
||||
17
relay/channel/siliconflow/dto.go
Normal file
17
relay/channel/siliconflow/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package siliconflow
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type SFTokens struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type SFMeta struct {
|
||||
Tokens SFTokens `json:"tokens"`
|
||||
}
|
||||
|
||||
type SFRerankResponse struct {
|
||||
Results []dto.RerankResponseDocument `json:"results"`
|
||||
Meta SFMeta `json:"meta"`
|
||||
}
|
||||
44
relay/channel/siliconflow/relay-siliconflow.go
Normal file
44
relay/channel/siliconflow/relay-siliconflow.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package siliconflow
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
|
||||
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
|
||||
}
|
||||
rerankResp := &dto.RerankResponse{
|
||||
Results: siliconflowResp.Results,
|
||||
Usage: *usage,
|
||||
}
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
}
|
||||
184
relay/channel/vertex/adaptor.go
Normal file
184
relay/channel/vertex/adaptor.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/copier"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeClaude = 1
|
||||
RequestModeGemini = 2
|
||||
RequestModeLlama = 3
|
||||
)
|
||||
|
||||
var claudeModelMap = map[string]string{
|
||||
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
|
||||
"claude-3-opus-20240229": "claude-3-opus@20240229",
|
||||
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
|
||||
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
|
||||
a.RequestMode = RequestModeGemini
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||
a.AccountCredentials = *adc
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if info.IsStream {
|
||||
suffix = "streamGenerateContent?alt=sse"
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
} else {
|
||||
suffix = "rawPredict"
|
||||
}
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
info.UpstreamModelName = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeClaude {
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vertexClaudeReq := &VertexAIClaudeRequest{
|
||||
AnthropicVersion: anthropicVersion,
|
||||
}
|
||||
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
|
||||
return nil, errors.New("failed to copy claude request")
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest := gemini.CovertGemini2OpenAI(*request)
|
||||
c.Set("request_model", request.Model)
|
||||
return geminiRequest, nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return request, nil
|
||||
}
|
||||
return nil, errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
}
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
15
relay/channel/vertex/constants.go
Normal file
15
relay/channel/vertex/constants.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package vertex
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
|
||||
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
|
||||
|
||||
"meta/llama3-405b-instruct-maas",
|
||||
}
|
||||
|
||||
var ChannelName = "vertex-ai"
|
||||
17
relay/channel/vertex/dto.go
Normal file
17
relay/channel/vertex/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package vertex
|
||||
|
||||
import "one-api/relay/channel/claude"
|
||||
|
||||
type VertexAIClaudeRequest struct {
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []claude.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
16
relay/channel/vertex/relay-vertex.go
Normal file
16
relay/channel/vertex/relay-vertex.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package vertex
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
func GetModelRegion(other string, localModelName string) string {
|
||||
// if other is json string
|
||||
if common.IsJsonStr(other) {
|
||||
m := common.StrToMap(other)
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
return m["default"].(string)
|
||||
}
|
||||
}
|
||||
return other
|
||||
}
|
||||
122
relay/channel/vertex/service_account.go
Normal file
122
relay/channel/vertex/service_account.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Credentials struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ClientEmail string `json:"client_email"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
||||
RefreshDuration: time.Minute * 35,
|
||||
EnableExpire: true,
|
||||
ExpireDuration: time.Minute * 30,
|
||||
Fetcher: func(key string) (interface{}, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
})
|
||||
|
||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
val, err := Cache.Get(cacheKey)
|
||||
if err == nil {
|
||||
return val.(string), nil
|
||||
}
|
||||
|
||||
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
newToken, err := exchangeJwtForAccessToken(signedJWT)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
|
||||
}
|
||||
if err := Cache.SetDefault(cacheKey, newToken); err {
|
||||
return newToken, nil
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
func createSignedJWT(email, privateKeyPEM string) (string, error) {
|
||||
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
|
||||
|
||||
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("failed to parse PEM block containing the private key")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("not an RSA private key")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": email,
|
||||
"scope": "https://www.googleapis.com/auth/cloud-platform",
|
||||
"aud": "https://www.googleapis.com/oauth2/v4/token",
|
||||
"exp": now.Add(time.Minute * 35).Unix(),
|
||||
"iat": now.Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signedToken, err := token.SignedString(rsaPrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
|
||||
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
resp, err := http.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package zhipu_4v
|
||||
|
||||
var ModelList = []string{
|
||||
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
|
||||
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus",
|
||||
}
|
||||
|
||||
var ChannelName = "zhipu_4v"
|
||||
|
||||
@@ -22,6 +22,7 @@ type RelayInfo struct {
|
||||
IsStream bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
@@ -57,6 +58,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: c.GetString("original_model"),
|
||||
UpstreamModelName: c.GetString("original_model"),
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
@@ -68,6 +71,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
||||
info.ChannelType == common.ChannelCloudflare {
|
||||
|
||||
@@ -23,6 +23,8 @@ const (
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -66,6 +68,10 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeJina
|
||||
case common.ChannelCloudflare:
|
||||
apiType = APITypeCloudflare
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
apiType = APITypeSiliconFlow
|
||||
case common.ChannelTypeVertexAi:
|
||||
apiType = APITypeVertexAi
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -38,9 +38,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
|
||||
// Not "256x256", "512x512", or "1024x1024"
|
||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||
@@ -50,6 +48,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
|
||||
}
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
//if imageRequest.N != 1 {
|
||||
// return nil, errors.New("n must be 1")
|
||||
//}
|
||||
|
||||
@@ -30,7 +30,7 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||
resp, err := common.ProxiedHttpGet(midjourneyTask.ImageUrl, common.OutProxyUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "http_get_image_failed",
|
||||
@@ -111,7 +111,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
||||
}
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
case relayconstant.RelayModeModerations:
|
||||
if textRequest.Input == "" {
|
||||
if textRequest.Input == "" || textRequest.Input == nil {
|
||||
return nil, errors.New("field input is required")
|
||||
}
|
||||
case relayconstant.RelayModeEdits:
|
||||
@@ -205,6 +205,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
prompts := textRequest.Prompt
|
||||
switch v := prompts.(type) {
|
||||
case string:
|
||||
prompts = v + textRequest.Suffix
|
||||
case []string:
|
||||
prompts = append(v, textRequest.Suffix)
|
||||
}
|
||||
|
||||
promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
@@ -317,7 +326,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
totalTokens := promptTokens + completionTokens
|
||||
var logContent string
|
||||
if !usePrice {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
@@ -353,9 +362,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(logModel, "g-") {
|
||||
logModel = "g-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/siliconflow"
|
||||
"one-api/relay/channel/task/suno"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
@@ -62,6 +64,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &jina.Adaptor{}
|
||||
case constant.APITypeCloudflare:
|
||||
return &cloudflare.Adaptor{}
|
||||
case constant.APITypeSiliconFlow:
|
||||
return &siliconflow.Adaptor{}
|
||||
case constant.APITypeVertexAi:
|
||||
return &vertex.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -38,6 +38,23 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[rerankRequest.Model] != "" {
|
||||
rerankRequest.Model = modelMap[rerankRequest.Model]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
|
||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
@@ -18,13 +18,14 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
|
||||
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
@@ -32,13 +33,15 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
|
||||
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
userRoute.GET("/epay/notify", controller.EpayNotify)
|
||||
userRoute.GET("/groups", controller.GetUserGroups)
|
||||
|
||||
selfRoute := userRoute.Group("/")
|
||||
selfRoute.Use(middleware.UserAuth())
|
||||
@@ -49,8 +52,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.POST("/topup", controller.TopUp)
|
||||
selfRoute.POST("/pay", controller.RequestEpay)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
}
|
||||
@@ -64,7 +67,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.POST("/", controller.CreateUser)
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
adminRoute.PUT("/", controller.UpdateUser)
|
||||
adminRoute.DELETE("/:id", controller.DeleteUser)
|
||||
adminRoute.DELETE("/:id", controller.HardDeleteUser)
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
|
||||
@@ -54,6 +54,8 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
|
||||
switch err.Error.Type {
|
||||
case "insufficient_quota":
|
||||
return true
|
||||
case "insufficient_user_quota":
|
||||
return true
|
||||
// https://docs.anthropic.com/claude/reference/errors
|
||||
case "authentication_error":
|
||||
return true
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if constant.CustomCallbackAddress == "" {
|
||||
return constant.ServerAddress
|
||||
}
|
||||
return constant.CustomCallbackAddress
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
var cl200kTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
@@ -31,19 +30,22 @@ func InitTokenEncoders() {
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
||||
|
||||
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.GetDefaultModelRatioMap() {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if strings.HasPrefix(model, "chatgpt-4o") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if "o1" == model || strings.HasPrefix(model, "o1") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = cl200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
}
|
||||
@@ -51,13 +53,6 @@ func InitTokenEncoders() {
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
return cl200kTokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
@@ -68,13 +63,12 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = getModelDefaultTokenEncoder(model)
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
// 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
|
||||
return getModelDefaultTokenEncoder(model)
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
@@ -111,10 +105,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
var err error
|
||||
var format string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
config, format, err = DecodeUrlImageData(imageUrl.Url)
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if constant.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
workerUrl := constant.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
|
||||
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
5
web/.gitignore
vendored
5
web/.gitignore
vendored
@@ -10,6 +10,7 @@
|
||||
|
||||
# production
|
||||
/build
|
||||
/dist
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
@@ -21,6 +22,4 @@
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.idea
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.idea/
|
||||
@@ -1 +1 @@
|
||||
module.exports = require("@so1ve/prettier-config");
|
||||
module.exports = require('@so1ve/prettier-config');
|
||||
|
||||
1897
web/pnpm-lock.yaml
generated
1897
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
400
web/src/App.js
400
web/src/App.js
@@ -11,6 +11,7 @@ import EditUser from './pages/User/EditUser';
|
||||
import { getLogo, getSystemName } from './helpers';
|
||||
import PasswordResetForm from './components/PasswordResetForm';
|
||||
import GitHubOAuth from './components/GitHubOAuth';
|
||||
import LinuxDoOAuth from './components/LinuxDoOAuth';
|
||||
import PasswordResetConfirm from './components/PasswordResetConfirm';
|
||||
import { UserContext } from './context/User';
|
||||
import Channel from './pages/Channel';
|
||||
@@ -20,11 +21,11 @@ import Redemption from './pages/Redemption';
|
||||
import TopUp from './pages/TopUp';
|
||||
import Log from './pages/Log';
|
||||
import Chat from './pages/Chat';
|
||||
import Chat2Link from './pages/Chat2Link';
|
||||
import { Layout } from '@douyinfe/semi-ui';
|
||||
import Midjourney from './pages/Midjourney';
|
||||
import Pricing from './pages/Pricing/index.js';
|
||||
import Task from "./pages/Task/index.js";
|
||||
// import Detail from './pages/Detail';
|
||||
import Task from './pages/Task/index.js';
|
||||
|
||||
const Home = lazy(() => import('./pages/Home'));
|
||||
const Detail = lazy(() => import('./pages/Detail'));
|
||||
@@ -58,207 +59,224 @@ function App() {
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Layout>
|
||||
<Layout.Content>
|
||||
<Routes>
|
||||
<Route
|
||||
path='/'
|
||||
element={
|
||||
<>
|
||||
<Routes>
|
||||
<Route
|
||||
path='/'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Home />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Channel />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/edit/:id'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/add'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/token'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Token />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/redemption'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Redemption />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<User />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit/:id'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/reset'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetConfirm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/login'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LoginForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/register'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<RegisterForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/reset'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/github'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<GitHubOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/linuxdo'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LinuxDoOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/setting'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Home />
|
||||
<Setting />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Channel />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/edit/:id'
|
||||
element={
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/topup'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
<TopUp />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/channel/add'
|
||||
element={
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/log'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Log />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/detail'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditChannel />
|
||||
<Detail />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/token'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Token />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/redemption'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Redemption />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<User />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit/:id'
|
||||
element={
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/midjourney'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
<Midjourney />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/edit'
|
||||
element={
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/task'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<EditUser />
|
||||
<Task />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/user/reset'
|
||||
element={
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/pricing'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Pricing />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/about'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<About />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/chat'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Chat />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
{/* 方便使用chat2link直接跳转聊天... */}
|
||||
<Route
|
||||
path='/chat2link'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetConfirm />
|
||||
<Chat2Link />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/login'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<LoginForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/register'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<RegisterForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/reset'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<PasswordResetForm />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/github'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<GitHubOAuth />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/setting'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Setting />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/topup'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<TopUp />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/log'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Log />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/detail'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Detail />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/midjourney'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Midjourney />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/task'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Task />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/pricing'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Pricing />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/about'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<About />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/chat'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<Chat />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route path='*' element={<NotFound />} />
|
||||
</Routes>
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route path='*' element={<NotFound />} />
|
||||
</Routes>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -98,14 +98,18 @@ const ChannelsTable = () => {
|
||||
render: (text, record, index) => {
|
||||
if (text === 3) {
|
||||
if (record.other_info === '') {
|
||||
record.other_info = '{}'
|
||||
record.other_info = '{}';
|
||||
}
|
||||
let otherInfo = JSON.parse(record.other_info);
|
||||
let reason = otherInfo['status_reason'];
|
||||
let time = otherInfo['status_time'];
|
||||
return (
|
||||
<div>
|
||||
<Tooltip content={'原因:' + reason + ',时间:' + timestamp2string(time)}>
|
||||
<Tooltip
|
||||
content={
|
||||
'原因:' + reason + ',时间:' + timestamp2string(time)
|
||||
}
|
||||
>
|
||||
{renderStatus(text)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
@@ -745,7 +749,7 @@ const ChannelsTable = () => {
|
||||
<Form.Select
|
||||
field='group'
|
||||
label='分组'
|
||||
optionList={[{ label: '选择分组', value: null}, ...groupOptions]}
|
||||
optionList={[{ label: '选择分组', value: null }, ...groupOptions]}
|
||||
initValue={null}
|
||||
onChange={(v) => {
|
||||
setSearchGroup(v);
|
||||
|
||||
@@ -3,7 +3,7 @@ import React, { useEffect, useState } from 'react';
|
||||
import { getFooterHTML, getSystemName } from '../helpers';
|
||||
import { Layout, Tooltip } from '@douyinfe/semi-ui';
|
||||
|
||||
const Footer = () => {
|
||||
const FooterBar = () => {
|
||||
const systemName = getSystemName();
|
||||
const [footer, setFooter] = useState(getFooterHTML());
|
||||
let remainCheckTimes = 5;
|
||||
@@ -25,11 +25,7 @@ const Footer = () => {
|
||||
New API {import.meta.env.VITE_REACT_APP_VERSION}{' '}
|
||||
</a>
|
||||
由{' '}
|
||||
<a
|
||||
href='https://github.com/Calcium-Ion'
|
||||
target='_blank'
|
||||
rel='noreferrer'
|
||||
>
|
||||
<a href='https://github.com/Calcium-Ion' target='_blank' rel='noreferrer'>
|
||||
Calcium-Ion
|
||||
</a>{' '}
|
||||
开发,基于{' '}
|
||||
@@ -56,21 +52,17 @@ const Footer = () => {
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Layout>
|
||||
<Layout.Content style={{ textAlign: 'center' }}>
|
||||
{footer ? (
|
||||
<Tooltip content={defaultFooter}>
|
||||
<div
|
||||
className='custom-footer'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
</Tooltip>
|
||||
) : (
|
||||
defaultFooter
|
||||
)}
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
<div style={{ textAlign: 'center' }}>
|
||||
{footer ? (
|
||||
<div
|
||||
className='custom-footer'
|
||||
dangerouslySetInnerHTML={{ __html: footer }}
|
||||
></div>
|
||||
) : (
|
||||
defaultFooter
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Footer;
|
||||
export default FooterBar;
|
||||
|
||||
@@ -14,9 +14,14 @@ const GitHubOAuth = () => {
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`);
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(
|
||||
`/api/oauth/github?code=${code}&state=${state}&aff=${aff}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
localStorage.removeItem('aff');
|
||||
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/setting');
|
||||
@@ -41,6 +46,14 @@ const GitHubOAuth = () => {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
|
||||
@@ -3,14 +3,23 @@ import { Link, useNavigate } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { useSetTheme, useTheme } from '../context/Theme';
|
||||
|
||||
import { API, getLogo, getSystemName, showSuccess } from '../helpers';
|
||||
import { API, getLogo, getSystemName, isMobile, showSuccess } from '../helpers';
|
||||
import '../index.css';
|
||||
|
||||
import fireworks from 'react-fireworks';
|
||||
|
||||
import { IconHelpCircle, IconKey, IconUser } from '@douyinfe/semi-icons';
|
||||
import {
|
||||
IconHelpCircle,
|
||||
IconHome,
|
||||
IconHomeStroked,
|
||||
IconKey,
|
||||
IconNoteMoneyStroked,
|
||||
IconPriceTag,
|
||||
IconUser,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { Avatar, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui';
|
||||
import { stringToColor } from '../helpers/render';
|
||||
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
||||
|
||||
// HeaderBar Buttons
|
||||
let headerButtons = [
|
||||
@@ -22,6 +31,21 @@ let headerButtons = [
|
||||
},
|
||||
];
|
||||
|
||||
let buttons = [
|
||||
{
|
||||
text: '首页',
|
||||
itemKey: 'home',
|
||||
to: '/',
|
||||
// icon: <IconHomeStroked />,
|
||||
},
|
||||
// {
|
||||
// text: '模型价格',
|
||||
// itemKey: 'pricing',
|
||||
// to: '/pricing',
|
||||
// icon: <IconNoteMoneyStroked />,
|
||||
// },
|
||||
];
|
||||
|
||||
if (localStorage.getItem('chat_link')) {
|
||||
headerButtons.splice(1, 0, {
|
||||
name: '聊天',
|
||||
@@ -90,6 +114,7 @@ const HeaderBar = () => {
|
||||
about: '/about',
|
||||
login: '/login',
|
||||
register: '/register',
|
||||
home: '/',
|
||||
};
|
||||
return (
|
||||
<Link
|
||||
@@ -103,6 +128,23 @@ const HeaderBar = () => {
|
||||
selectedKeys={[]}
|
||||
// items={headerButtons}
|
||||
onSelect={(key) => {}}
|
||||
header={
|
||||
isMobile()
|
||||
? {
|
||||
logo: (
|
||||
<img
|
||||
src={logo}
|
||||
alt='logo'
|
||||
style={{ marginRight: '0.75em' }}
|
||||
/>
|
||||
),
|
||||
}
|
||||
: {
|
||||
logo: <img src={logo} alt='logo' />,
|
||||
text: systemName,
|
||||
}
|
||||
}
|
||||
items={buttons}
|
||||
footer={
|
||||
<>
|
||||
{isNewYear && (
|
||||
@@ -121,15 +163,19 @@ const HeaderBar = () => {
|
||||
</Dropdown>
|
||||
)}
|
||||
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
|
||||
<Switch
|
||||
checkedText='🌞'
|
||||
size={'large'}
|
||||
checked={theme === 'dark'}
|
||||
uncheckedText='🌙'
|
||||
onChange={(checked) => {
|
||||
setTheme(checked);
|
||||
}}
|
||||
/>
|
||||
<>
|
||||
{!isMobile() && (
|
||||
<Switch
|
||||
checkedText='🌞'
|
||||
size={'large'}
|
||||
checked={theme === 'dark'}
|
||||
uncheckedText='🌙'
|
||||
onChange={(checked) => {
|
||||
setTheme(checked);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
{userState.user ? (
|
||||
<>
|
||||
<Dropdown
|
||||
@@ -155,7 +201,7 @@ const HeaderBar = () => {
|
||||
<Nav.Item
|
||||
itemKey={'login'}
|
||||
text={'登录'}
|
||||
icon={<IconKey />}
|
||||
// icon={<IconKey />}
|
||||
/>
|
||||
<Nav.Item
|
||||
itemKey={'register'}
|
||||
|
||||
27
web/src/components/LinuxDoIcon.js
Normal file
27
web/src/components/LinuxDoIcon.js
Normal file
@@ -0,0 +1,27 @@
|
||||
import React from 'react';
|
||||
import { Icon } from '@douyinfe/semi-ui';
|
||||
|
||||
const LinuxDoIcon = (props) => {
|
||||
function CustomIcon() {
|
||||
return (
|
||||
<svg
|
||||
className='icon'
|
||||
viewBox='0 0 24 24'
|
||||
version='1.1'
|
||||
xmlns='http://www.w3.org/2000/svg'
|
||||
width='1em'
|
||||
height='1em'
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d='M19.7,17.6c-0.1-0.2-0.2-0.4-0.2-0.6c0-0.4-0.2-0.7-0.5-1c-0.1-0.1-0.3-0.2-0.4-0.2c0.6-1.8-0.3-3.6-1.3-4.9c0,0,0,0,0,0c-0.8-1.2-2-2.1-1.9-3.7c0-1.9,0.2-5.4-3.3-5.1C8.5,2.3,9.5,6,9.4,7.3c0,1.1-0.5,2.2-1.3,3.1c-0.2,0.2-0.4,0.5-0.5,0.7c-1,1.2-1.5,2.8-1.5,4.3c-0.2,0.2-0.4,0.4-0.5,0.6c-0.1,0.1-0.2,0.2-0.2,0.3c-0.1,0.1-0.3,0.2-0.5,0.3c-0.4,0.1-0.7,0.3-0.9,0.7c-0.1,0.3-0.2,0.7-0.1,1.1c0.1,0.2,0.1,0.4,0,0.7c-0.2,0.4-0.2,0.9,0,1.4c0.3,0.4,0.8,0.5,1.5,0.6c0.5,0,1.1,0.2,1.6,0.4l0,0c0.5,0.3,1.1,0.5,1.7,0.5c0.3,0,0.7-0.1,1-0.2c0.3-0.2,0.5-0.4,0.6-0.7c0.4,0,1-0.2,1.7-0.2c0.6,0,1.2,0.2,2,0.1c0,0.1,0,0.2,0.1,0.3c0.2,0.5,0.7,0.9,1.3,1c0.1,0,0.1,0,0.2,0c0.8-0.1,1.6-0.5,2.1-1.1l0,0c0.4-0.4,0.9-0.7,1.4-0.9c0.6-0.3,1-0.5,1.1-1C20.3,18.6,20.1,18.2,19.7,17.6z M12.8,4.8c0.6,0.1,1.1,0.6,1,1.2c0,0.3-0.1,0.6-0.3,0.9c0,0,0,0-0.1,0c-0.2-0.1-0.3-0.1-0.4-0.2c0.1-0.1,0.1-0.3,0.2-0.5c0-0.4-0.2-0.7-0.4-0.7c-0.3,0-0.5,0.3-0.5,0.7c0,0,0,0.1,0,0.1c-0.1-0.1-0.3-0.1-0.4-0.2c0,0,0-0.1,0-0.1C11.8,5.5,12.2,4.9,12.8,4.8z M12.5,6.8c0.1,0.1,0.3,0.2,0.4,0.2c0.1,0,0.3,0.1,0.4,0.2c0.2,0.1,0.4,0.2,0.4,0.5c0,0.3-0.3,0.6-0.9,0.8c-0.2,0.1-0.3,0.1-0.4,0.2c-0.3,0.2-0.6,0.3-1,0.3c-0.3,0-0.6-0.2-0.8-0.4c-0.1-0.1-0.2-0.2-0.4-0.3C10.1,8.2,9.9,8,9.8,7.7c0-0.1,0.1-0.2,0.2-0.3c0.3-0.2,0.4-0.3,0.5-0.4l0.1-0.1c0.2-0.3,0.6-0.5,1-0.5C11.9,6.5,12.2,6.6,12.5,6.8z M10.4,5c0.4,0,0.7,0.4,0.8,1.1c0,0.1,0,0.1,0,0.2c-0.1,0-0.3,0.1-0.4,0.2c0,0,0-0.1,0-0.2c0-0.3-0.2-0.6-0.4-0.5c-0.2,0-0.3,0.3-0.3,0.6c0,0.2,0.1,0.3,0.2,0.4l0,0c0,0-0.1,0.1-0.2,0.1C9.9,6.7,9.7,6.4,9.7,6.1C9.7,5.5,10,5,10.4,5z M9.4,21.1c-0.7,0.3-1.6,0.2-2.2-0.2c-0.6-0.3-1.1-0.4-1.8-0.4c-0.5-0.1-1-0.1-1.1-0.3c-0.1-0.2-0.1-0.5,0.1-1c0.1-0.3,0.1-0.6,0-0.9c-0.1-0.3-0.1-0.5,0-0.8C4.5,17.2,4.7,17.1,5,17c0.3-0.1,0.5-0.2,0.7-0.4c0.1-0.1,0.2-0.2,0.3-0.4c0.3-0.4,0.5-0.6,0.8-0.6c0.6,0.1,1.1,1,1.5,1.9c0.2,0.3,0.4,0.7,0.7,1c0.4,0.5,0.9,1.2,0.9,1.6C9.9,20.6,9.7,20.9,9.4,21.1z M14.3,18.9c0,0.1,0,0.1-0.1,0.2c-1.2,0.9-2.8,1-4.1,0.3c-0.2-0.3-0.4-0.6-0.6-0.9c0.9-0.1,0.7-1.3-1.2-2.5c-2-1.3-0.6-3.7,0.1-4.8c0.1-0.1,0.1,0-0.3,0.8c-0.3,0.6-0.9,2.1-0.1,3.2c0-0.8,0.2-1.6,0.5-2.4c0.7-1.3,1.2-2.8,1.5-4.3c0.1,0.1,0.1,0.1,0.2,0.1c0.1,0.1,0.2,0.2,0.3,0.2c0.2,0.3,0.6,0.4,0.9,0.4c0,0,0.1,0,0.1,0c0.4,0,0.8-0.1,1.1-0.4c0.1-0.1,0.2-0.2,0.4-0.2c0.3-0.1,0.6-0.3,0.9-0.6c0.4,1.3,0.8,2.5,1.4,3.6c0.4,0.8,0.7,1.6,0.9,2.5c0.3,0,0.7,0.1,1,0.3c0.8,0.4,1.1,0.7,1,1.2c-0.1,0-0.1,0-0.2,0c0-0.3-0.2-0.6-0.9-0.9c-0.7-0.3-1.3-0.3-1.5,0.4c-0.1,0-0.2,0.1-0.3,0.1c-0.8,0.4-0.8,1.5-0.9,2.6C14.5,18.2,14.4,18.5,14.3,18.9z M18.9,19.5c-0.6,0.2-1.1,0.6-1.5,1.1c-0.4,0.6-1.1,1-1.9,0.9c-0.4,0-0.8-0.3-0.9-0.7c-0.1-0.6-0.1-1.2,0.2-1.8c0.1-0.4,0.2-0.7,0.3-1.1c0.1-1.2,0.1-1.9,0.6-2.2h0c0,0.5,0.3,0.8,0.7,1c0.5,0,1-0.1,1.4-0.5c0.1,0,0.1,0,0.2,0c0.3,0,0.5,0,0.7,0.2c0.2,0.2,0.3,0.5,0.3,0.7c0,0.3,0.2,0.6,0.3,0.9c0.5,0.5,0.5,0.8,0.5,0.9C19.7,19.1,19.3,19.3,18.9,19.5z M9.9,7.5c-0.1,0-0.1,0-0.1,0.1c0,0,0,0.1,0.1,0.1c0,0,0,0,0,0c0.1,0,0.1,0.1,0.1,0.1c0.3,0.4,0.8,0.6,1.4,0.7c0.5-0.1,1-0.2,1.5-0.6c0.2-0.1,0.4-0.2,0.6-0.3c0.1,0,0.1-0.1,0.1-0.1c0-0.1,0-0.1-0.1-0.1l0,0c-0.2,0.1-0.5,0.2-0.7,0.3c-0.4,0.3-0.9,0.5-1.4,0.5c-0.5,0-0.9-0.3-1.2-0.6C10.1,7.6,10,7.5,9.9,7.5z'
|
||||
fill='currentColor'
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
return <Icon svg={<CustomIcon />} />;
|
||||
};
|
||||
|
||||
export default LinuxDoIcon;
|
||||
71
web/src/components/LinuxDoOAuth.js
Normal file
71
web/src/components/LinuxDoOAuth.js
Normal file
@@ -0,0 +1,71 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { API, showError, showSuccess } from '../helpers';
|
||||
import { UserContext } from '../context/User';
|
||||
|
||||
const LinuxDoOAuth = () => {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [prompt, setPrompt] = useState('处理中...');
|
||||
const [processing, setProcessing] = useState(true);
|
||||
|
||||
let navigate = useNavigate();
|
||||
|
||||
const sendCode = async (code, state, count) => {
|
||||
let aff = localStorage.getItem('aff');
|
||||
const res = await API.get(
|
||||
`/api/oauth/linuxdo?code=${code}&state=${state}&aff=${aff}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
localStorage.removeItem('aff');
|
||||
|
||||
if (message === 'bind') {
|
||||
showSuccess('绑定成功!');
|
||||
navigate('/setting');
|
||||
} else {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
localStorage.setItem('user', JSON.stringify(data));
|
||||
showSuccess('登录成功!');
|
||||
navigate('/');
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
if (count === 0) {
|
||||
setPrompt(`操作失败,重定向至登录界面中...`);
|
||||
navigate('/setting'); // in case this is failed to bind GitHub
|
||||
return;
|
||||
}
|
||||
count++;
|
||||
setPrompt(`出现错误,第 ${count} 次重试中...`);
|
||||
await new Promise((resolve) => setTimeout(resolve, count * 2000));
|
||||
await sendCode(code, state, count);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
let error = searchParams.get('error');
|
||||
if (error) {
|
||||
let errorDescription = searchParams.get('error_description');
|
||||
showError(`授权错误:${error}: ${errorDescription}`);
|
||||
navigate('/setting');
|
||||
return;
|
||||
}
|
||||
|
||||
let code = searchParams.get('code');
|
||||
let state = searchParams.get('state');
|
||||
sendCode(code, state, 0).then();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Segment style={{ minHeight: '300px' }}>
|
||||
<Dimmer active inverted>
|
||||
<Loader size='large'>{prompt}</Loader>
|
||||
</Dimmer>
|
||||
</Segment>
|
||||
);
|
||||
};
|
||||
|
||||
export default LinuxDoOAuth;
|
||||
@@ -1,8 +1,15 @@
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../context/User';
|
||||
import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
|
||||
import { onGitHubOAuthClicked } from './utils';
|
||||
import {
|
||||
API,
|
||||
getLogo,
|
||||
showError,
|
||||
showInfo,
|
||||
showSuccess,
|
||||
updateAPI,
|
||||
} from '../helpers';
|
||||
import { onGitHubOAuthClicked, onLinuxDoOAuthClicked } from './utils';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import {
|
||||
Button,
|
||||
@@ -18,6 +25,7 @@ import Text from '@douyinfe/semi-ui/lib/es/typography/text';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
|
||||
import { IconGithubLogo } from '@douyinfe/semi-icons';
|
||||
import LinuxDoIcon from './LinuxDoIcon';
|
||||
import WeChatIcon from './WeChatIcon';
|
||||
import { setUserData } from '../helpers/data.js';
|
||||
|
||||
@@ -101,7 +109,7 @@ const LoginForm = () => {
|
||||
if (success) {
|
||||
userDispatch({ type: 'login', payload: data });
|
||||
setUserData(data);
|
||||
updateAPI()
|
||||
updateAPI();
|
||||
showSuccess('登录成功!');
|
||||
if (username === 'root' && password === '123456') {
|
||||
Modal.error({
|
||||
@@ -209,6 +217,7 @@ const LoginForm = () => {
|
||||
</Text>
|
||||
</div>
|
||||
{status.github_oauth ||
|
||||
status.linuxdo_oauth ||
|
||||
status.wechat_login ||
|
||||
status.telegram_oauth ? (
|
||||
<>
|
||||
@@ -233,35 +242,43 @@ const LoginForm = () => {
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.linuxdo_oauth ? (
|
||||
<Button
|
||||
type='primary'
|
||||
icon={<LinuxDoIcon />}
|
||||
style={{ color: '#000', margin: '0 5px' }}
|
||||
onClick={() =>
|
||||
onLinuxDoOAuthClicked(status.linuxdo_client_id)
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
{status.wechat_login ? (
|
||||
<Button
|
||||
type='primary'
|
||||
style={{ color: 'rgba(var(--semi-green-5), 1)' }}
|
||||
style={{
|
||||
color: 'rgba(var(--semi-green-5), 1)',
|
||||
margin: '0 5px',
|
||||
}}
|
||||
icon={<Icon svg={<WeChatIcon />} />}
|
||||
onClick={onWeChatLoginClicked}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
||||
{status.telegram_oauth ? (
|
||||
<TelegramLoginButton
|
||||
className='semi-button semi-button-with-icon semi-button-with-icon-only'
|
||||
buttonSize='medium'
|
||||
dataOnauth={onTelegramLoginClicked}
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
{status.telegram_oauth ? (
|
||||
<>
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
marginTop: 5,
|
||||
}}
|
||||
>
|
||||
<TelegramLoginButton
|
||||
dataOnauth={onTelegramLoginClicked}
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<></>
|
||||
|
||||
@@ -92,9 +92,9 @@ function renderType(type) {
|
||||
);
|
||||
case 'UPLOAD':
|
||||
return (
|
||||
<Tag color='blue' size='large'>
|
||||
上传文件
|
||||
</Tag>
|
||||
<Tag color='blue' size='large'>
|
||||
上传文件
|
||||
</Tag>
|
||||
);
|
||||
case 'SHORTEN':
|
||||
return (
|
||||
@@ -262,7 +262,7 @@ function renderDuration(submit_time, finishTime) {
|
||||
|
||||
// 返回带有样式的颜色标签
|
||||
return (
|
||||
<Tag color={color} size="large">
|
||||
<Tag color={color} size='large'>
|
||||
{durationSec} 秒
|
||||
</Tag>
|
||||
);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user