mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-24 10:23:41 +08:00
Compare commits
19 Commits
v0.4.10-al
...
v0.5.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc2f48b1f2 | ||
|
|
889af8b2db | ||
|
|
4eea096654 | ||
|
|
4ab3211c0e | ||
|
|
3da119efba | ||
|
|
dccd66b852 | ||
|
|
2fcd6852e0 | ||
|
|
9b4d1964d4 | ||
|
|
806bf8241c | ||
|
|
ce93c9b6b2 | ||
|
|
4ec4289565 | ||
|
|
3dc5a0f91d | ||
|
|
80a846673a | ||
|
|
26c6719ea3 | ||
|
|
c87e05bfc2 | ||
|
|
e6938bd236 | ||
|
|
8f721d67a5 | ||
|
|
fcc1e2d568 | ||
|
|
9a1db61675 |
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
# One API
|
# One API
|
||||||
|
|
||||||
_✨ An OpenAI key management & redistribution system, easy to deploy & use ✨_
|
_✨ Access all LLM through the standard OpenAI API format, easy to deploy & use ✨_
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
51
README.md
51
README.md
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
# One API
|
# One API
|
||||||
|
|
||||||
_✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用✨_
|
_✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -58,10 +58,13 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
|
|||||||
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
|
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
|
||||||
|
|
||||||
## 功能
|
## 功能
|
||||||
1. 支持多种 API 访问渠道:
|
1. 支持多种大模型:
|
||||||
+ [x] OpenAI 官方通道(支持配置镜像)
|
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||||
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
||||||
+ [x] **Azure OpenAI API**
|
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
|
||||||
|
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||||
|
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||||
|
2. 支持配置镜像以及众多第三方代理服务:
|
||||||
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
@@ -69,32 +72,30 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
|
|||||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
||||||
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
||||||
2. 支持通过**负载均衡**的方式访问多个渠道。
|
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||||
3. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||||
4. 支持**多机部署**,[详见此处](#多机部署)。
|
5. 支持**多机部署**,[详见此处](#多机部署)。
|
||||||
5. 支持**令牌管理**,设置令牌的过期时间和额度。
|
6. 支持**令牌管理**,设置令牌的过期时间和额度。
|
||||||
6. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
|
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
|
||||||
7. 支持**通道管理**,批量创建通道。
|
8. 支持**通道管理**,批量创建通道。
|
||||||
8. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
|
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
|
||||||
9. 支持渠道**设置模型列表**。
|
10. 支持渠道**设置模型列表**。
|
||||||
10. 支持**查看额度明细**。
|
11. 支持**查看额度明细**。
|
||||||
11. 支持**用户邀请奖励**。
|
12. 支持**用户邀请奖励**。
|
||||||
12. 支持以美元为单位显示额度。
|
13. 支持以美元为单位显示额度。
|
||||||
13. 支持发布公告,设置充值链接,设置新用户初始额度。
|
14. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||||
14. 支持模型映射,重定向用户的请求模型。
|
15. 支持模型映射,重定向用户的请求模型。
|
||||||
15. 支持失败自动重试。
|
16. 支持失败自动重试。
|
||||||
16. 支持绘图接口。
|
17. 支持绘图接口。
|
||||||
17. 支持丰富的**自定义**设置,
|
18. 支持丰富的**自定义**设置,
|
||||||
1. 支持自定义系统名称,logo 以及页脚。
|
1. 支持自定义系统名称,logo 以及页脚。
|
||||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||||
18. 支持通过系统访问令牌访问管理 API。
|
19. 支持通过系统访问令牌访问管理 API。
|
||||||
19. 支持 Cloudflare Turnstile 用户校验。
|
20. 支持 Cloudflare Turnstile 用户校验。
|
||||||
20. 支持用户管理,支持**多种用户登录注册方式**:
|
21. 支持用户管理,支持**多种用户登录注册方式**:
|
||||||
+ 邮箱登录注册以及通过邮箱进行密码重置。
|
+ 邮箱登录注册以及通过邮箱进行密码重置。
|
||||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||||
21. 支持 [ChatGLM](https://github.com/THUDM/ChatGLM2-6B)。
|
|
||||||
22. 未来其他大模型开放 API 后,将第一时间支持,并将其封装成同样的 API 访问方式。
|
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||||
|
|
||||||
|
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
@@ -152,6 +154,8 @@ const (
|
|||||||
ChannelTypeAPI2GPT = 12
|
ChannelTypeAPI2GPT = 12
|
||||||
ChannelTypeAIGC2D = 13
|
ChannelTypeAIGC2D = 13
|
||||||
ChannelTypeAnthropic = 14
|
ChannelTypeAnthropic = 14
|
||||||
|
ChannelTypeBaidu = 15
|
||||||
|
ChannelTypeZhipu = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -170,4 +174,6 @@ var ChannelBaseURLs = []string{
|
|||||||
"https://api.api2gpt.com", // 12
|
"https://api.api2gpt.com", // 12
|
||||||
"https://api.aigc2d.com", // 13
|
"https://api.aigc2d.com", // 13
|
||||||
"https://api.anthropic.com", // 14
|
"https://api.anthropic.com", // 14
|
||||||
|
"https://aip.baidubce.com", // 15
|
||||||
|
"https://open.bigmodel.cn", // 16
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import "encoding/json"
|
|||||||
|
|
||||||
// ModelRatio
|
// ModelRatio
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||||
// https://openai.com/pricing
|
// https://openai.com/pricing
|
||||||
// TODO: when a new api is enabled, check the pricing here
|
// TODO: when a new api is enabled, check the pricing here
|
||||||
// 1 === $0.002 / 1K tokens
|
// 1 === $0.002 / 1K tokens
|
||||||
|
// 1 === ¥0.014 / 1k tokens
|
||||||
var ModelRatio = map[string]float64{
|
var ModelRatio = map[string]float64{
|
||||||
"gpt-4": 15,
|
"gpt-4": 15,
|
||||||
"gpt-4-0314": 15,
|
"gpt-4-0314": 15,
|
||||||
@@ -38,6 +40,12 @@ var ModelRatio = map[string]float64{
|
|||||||
"dall-e": 8,
|
"dall-e": 8,
|
||||||
"claude-instant-1": 0.75,
|
"claude-instant-1": 0.75,
|
||||||
"claude-2": 30,
|
"claude-2": 30,
|
||||||
|
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||||
|
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||||
|
"PaLM-2": 1,
|
||||||
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
|
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ func GetSubscription(c *gin.Context) {
|
|||||||
var usedQuota int
|
var usedQuota int
|
||||||
var err error
|
var err error
|
||||||
var token *model.Token
|
var token *model.Token
|
||||||
|
var expiredTime int64
|
||||||
if common.DisplayTokenStatEnabled {
|
if common.DisplayTokenStatEnabled {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
token, err = model.GetTokenById(tokenId)
|
token, err = model.GetTokenById(tokenId)
|
||||||
|
expiredTime = token.ExpiredTime
|
||||||
remainQuota = token.RemainQuota
|
remainQuota = token.RemainQuota
|
||||||
usedQuota = token.UsedQuota
|
usedQuota = token.UsedQuota
|
||||||
} else {
|
} else {
|
||||||
@@ -21,6 +23,9 @@ func GetSubscription(c *gin.Context) {
|
|||||||
remainQuota, err = model.GetUserQuota(userId)
|
remainQuota, err = model.GetUserQuota(userId)
|
||||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||||
}
|
}
|
||||||
|
if expiredTime <= 0 {
|
||||||
|
expiredTime = 0
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
@@ -45,6 +50,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
SoftLimitUSD: amount,
|
SoftLimitUSD: amount,
|
||||||
HardLimitUSD: amount,
|
HardLimitUSD: amount,
|
||||||
SystemHardLimitUSD: amount,
|
SystemHardLimitUSD: amount,
|
||||||
|
AccessUntil: expiredTime,
|
||||||
}
|
}
|
||||||
c.JSON(200, subscription)
|
c.JSON(200, subscription)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ type OpenAISubscriptionResponse struct {
|
|||||||
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
||||||
HardLimitUSD float64 `json:"hard_limit_usd"`
|
HardLimitUSD float64 `json:"hard_limit_usd"`
|
||||||
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
||||||
|
AccessUntil int64 `json:"access_until"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIUsageDailyCost struct {
|
type OpenAIUsageDailyCost struct {
|
||||||
@@ -84,7 +85,6 @@ func GetAuthHeader(token string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
||||||
client := &http.Client{}
|
|
||||||
req, err := http.NewRequest(method, url, nil)
|
req, err := http.NewRequest(method, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -92,10 +92,13 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := client.Do(req)
|
res, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("status code: %d", res.StatusCode)
|
||||||
|
}
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -16,6 +16,14 @@ import (
|
|||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
|
case common.ChannelTypePaLM:
|
||||||
|
fallthrough
|
||||||
|
case common.ChannelTypeAnthropic:
|
||||||
|
fallthrough
|
||||||
|
case common.ChannelTypeBaidu:
|
||||||
|
fallthrough
|
||||||
|
case common.ChannelTypeZhipu:
|
||||||
|
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
request.Model = "gpt-35-turbo"
|
request.Model = "gpt-35-turbo"
|
||||||
default:
|
default:
|
||||||
@@ -45,8 +53,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
|
|||||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
client := &http.Client{}
|
resp, err := httpClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -127,8 +127,9 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
||||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||||
"<p>点击<a href='%s'>此处</a>进行密码重置。</p>"+
|
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, common.VerificationValidMinutes)
|
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||||
|
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := common.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -252,24 +252,6 @@ func init() {
|
|||||||
Root: "code-davinci-edit-001",
|
Root: "code-davinci-edit-001",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Id: "ChatGLM",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "thudm",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ChatGLM",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Id: "ChatGLM2",
|
|
||||||
Object: "model",
|
|
||||||
Created: 1677649963,
|
|
||||||
OwnedBy: "thudm",
|
|
||||||
Permission: permission,
|
|
||||||
Root: "ChatGLM2",
|
|
||||||
Parent: nil,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Id: "claude-instant-1",
|
Id: "claude-instant-1",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -288,6 +270,60 @@ func init() {
|
|||||||
Root: "claude-2",
|
Root: "claude-2",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "ERNIE-Bot",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "baidu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ERNIE-Bot",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "ERNIE-Bot-turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "baidu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ERNIE-Bot-turbo",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "PaLM-2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "PaLM-2",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "chatglm_pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "zhipu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "chatglm_pro",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "chatglm_std",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "zhipu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "chatglm_std",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "chatglm_lite",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "zhipu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "chatglm_lite",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]OpenAIModels)
|
openAIModelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range openAIModels {
|
for _, model := range openAIModels {
|
||||||
|
|||||||
203
controller/relay-baidu.go
Normal file
203
controller/relay-baidu.go
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||||
|
|
||||||
|
type BaiduTokenResponse struct {
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
SessionKey string `json:"session_key"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
SessionSecret string `json:"session_secret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatRequest struct {
|
||||||
|
Messages []BaiduMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
UserId string `json:"user_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduError struct {
|
||||||
|
ErrorCode int `json:"error_code"`
|
||||||
|
ErrorMsg string `json:"error_msg"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatResponse struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
IsTruncated bool `json:"is_truncated"`
|
||||||
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
BaiduError
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaiduChatStreamResponse struct {
|
||||||
|
BaiduChatResponse
|
||||||
|
SentenceId int `json:"sentence_id"`
|
||||||
|
IsEnd bool `json:"is_end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||||
|
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
messages = append(messages, BaiduMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &BaiduChatRequest{
|
||||||
|
Messages: messages,
|
||||||
|
Stream: request.Stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response.Result,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: response.Id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: response.Created,
|
||||||
|
Choices: []OpenAITextResponseChoice{choice},
|
||||||
|
Usage: response.Usage,
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = baiduResponse.Result
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Id: baiduResponse.Id,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: baiduResponse.Created,
|
||||||
|
Model: "ernie-bot",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[6:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var baiduResponse BaiduChatStreamResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
usage.PromptTokens += baiduResponse.Usage.PromptTokens
|
||||||
|
usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
||||||
|
usage.TotalTokens += baiduResponse.Usage.TotalTokens
|
||||||
|
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var baiduResponse BaiduChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if baiduResponse.ErrorMsg != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: baiduResponse.ErrorMsg,
|
||||||
|
Type: "baidu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: baiduResponse.ErrorCode,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
@@ -109,8 +109,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
|
||||||
client := &http.Client{}
|
resp, err := httpClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
|
|||||||
}
|
}
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
// So the client will be confused by the response.
|
// So the httpClient will be confused by the response.
|
||||||
// For example, Postman will report error, and we cannot check the response at all.
|
// For example, Postman will report error, and we cannot check the response at all.
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||||
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||||
|
|
||||||
type PaLMChatMessage struct {
|
type PaLMChatMessage struct {
|
||||||
Author string `json:"author"`
|
Author string `json:"author"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
@@ -15,45 +22,188 @@ type PaLMFilter struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
type PaLMPrompt struct {
|
||||||
|
Messages []PaLMChatMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
type PaLMChatRequest struct {
|
type PaLMChatRequest struct {
|
||||||
Prompt []Message `json:"prompt"`
|
Prompt PaLMPrompt `json:"prompt"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
CandidateCount int `json:"candidateCount"`
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
TopP float64 `json:"topP"`
|
TopP float64 `json:"topP,omitempty"`
|
||||||
TopK int `json:"topK"`
|
TopK int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaLMError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
|
||||||
type PaLMChatResponse struct {
|
type PaLMChatResponse struct {
|
||||||
Candidates []Message `json:"candidates"`
|
Candidates []PaLMChatMessage `json:"candidates"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Filters []PaLMFilter `json:"filters"`
|
Filters []PaLMFilter `json:"filters"`
|
||||||
|
Error PaLMError `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorWithStatusCode {
|
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
|
palmRequest := PaLMChatRequest{
|
||||||
messages := make([]PaLMChatMessage, 0, len(openAIRequest.Messages))
|
Prompt: PaLMPrompt{
|
||||||
for _, message := range openAIRequest.Messages {
|
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||||
var author string
|
},
|
||||||
if message.Role == "user" {
|
Temperature: textRequest.Temperature,
|
||||||
author = "0"
|
CandidateCount: textRequest.N,
|
||||||
} else {
|
TopP: textRequest.TopP,
|
||||||
author = "1"
|
TopK: textRequest.MaxTokens,
|
||||||
}
|
}
|
||||||
messages = append(messages, PaLMChatMessage{
|
for _, message := range textRequest.Messages {
|
||||||
Author: author,
|
palmMessage := PaLMChatMessage{
|
||||||
Content: message.Content,
|
Content: message.Content,
|
||||||
})
|
}
|
||||||
|
if message.Role == "user" {
|
||||||
|
palmMessage.Author = "0"
|
||||||
|
} else {
|
||||||
|
palmMessage.Author = "1"
|
||||||
|
}
|
||||||
|
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||||
}
|
}
|
||||||
request := PaLMChatRequest{
|
return &palmRequest
|
||||||
Prompt: nil,
|
}
|
||||||
Temperature: openAIRequest.Temperature,
|
|
||||||
CandidateCount: openAIRequest.N,
|
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||||
TopP: openAIRequest.TopP,
|
fullTextResponse := OpenAITextResponse{
|
||||||
TopK: openAIRequest.MaxTokens,
|
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||||
}
|
}
|
||||||
// TODO: forward request to PaLM & convert response
|
for i, candidate := range response.Candidates {
|
||||||
fmt.Print(request)
|
choice := OpenAITextResponseChoice{
|
||||||
return nil
|
Index: i,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: candidate.Content,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
var response ChatCompletionsStreamResponse
|
||||||
|
response.Object = "chat.completion.chunk"
|
||||||
|
response.Model = "palm2"
|
||||||
|
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
responseText := ""
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error reading stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error closing stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var palmResponse PaLMChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
|
||||||
|
fullTextResponse.Id = responseId
|
||||||
|
fullTextResponse.Created = createdTime
|
||||||
|
if len(palmResponse.Candidates) > 0 {
|
||||||
|
responseText = palmResponse.Candidates[0].Content
|
||||||
|
}
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
stopChan <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataChan <- string(jsonResponse)
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + data})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var palmResponse PaLMChatResponse
|
||||||
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: palmResponse.Error.Message,
|
||||||
|
Type: palmResponse.Error.Status,
|
||||||
|
Param: "",
|
||||||
|
Code: palmResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
|
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
||||||
|
usage := Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: completionTokens,
|
||||||
|
TotalTokens: promptTokens + completionTokens,
|
||||||
|
}
|
||||||
|
fullTextResponse.Usage = usage
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,8 +18,16 @@ const (
|
|||||||
APITypeOpenAI = iota
|
APITypeOpenAI = iota
|
||||||
APITypeClaude
|
APITypeClaude
|
||||||
APITypePaLM
|
APITypePaLM
|
||||||
|
APITypeBaidu
|
||||||
|
APITypeZhipu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var httpClient *http.Client
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
httpClient = &http.Client{}
|
||||||
|
}
|
||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
@@ -79,6 +87,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiType := APITypeOpenAI
|
apiType := APITypeOpenAI
|
||||||
if strings.HasPrefix(textRequest.Model, "claude") {
|
if strings.HasPrefix(textRequest.Model, "claude") {
|
||||||
apiType = APITypeClaude
|
apiType = APITypeClaude
|
||||||
|
} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
|
||||||
|
apiType = APITypeBaidu
|
||||||
|
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
|
||||||
|
apiType = APITypePaLM
|
||||||
|
} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
|
||||||
|
apiType = APITypeZhipu
|
||||||
}
|
}
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
@@ -112,6 +126,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if baseURL != "" {
|
if baseURL != "" {
|
||||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
|
||||||
}
|
}
|
||||||
|
case APITypeBaidu:
|
||||||
|
switch textRequest.Model {
|
||||||
|
case "ERNIE-Bot":
|
||||||
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||||
|
case "ERNIE-Bot-turbo":
|
||||||
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||||
|
case "BLOOMZ-7B":
|
||||||
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||||
|
}
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
|
||||||
|
case APITypePaLM:
|
||||||
|
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
fullRequestURL += "?key=" + apiKey
|
||||||
|
case APITypeZhipu:
|
||||||
|
method := "invoke"
|
||||||
|
if textRequest.Stream {
|
||||||
|
method = "sse-invoke"
|
||||||
|
}
|
||||||
|
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||||
}
|
}
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
var completionTokens int
|
var completionTokens int
|
||||||
@@ -164,6 +201,27 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeBaidu:
|
||||||
|
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(baiduRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypePaLM:
|
||||||
|
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(palmRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeZhipu:
|
||||||
|
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||||
|
jsonStr, err := json.Marshal(zhipuRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,12 +243,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
anthropicVersion = "2023-06-01"
|
anthropicVersion = "2023-06-01"
|
||||||
}
|
}
|
||||||
req.Header.Set("anthropic-version", anthropicVersion)
|
req.Header.Set("anthropic-version", anthropicVersion)
|
||||||
|
case APITypeZhipu:
|
||||||
|
token := getZhipuToken(apiKey)
|
||||||
|
req.Header.Set("Authorization", token)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||||
client := &http.Client{}
|
resp, err := httpClient.Do(req)
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@@ -216,11 +276,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
||||||
completionRatio = 2
|
completionRatio = 2
|
||||||
}
|
}
|
||||||
if isStream {
|
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
|
||||||
completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
completionTokens = countTokenText(streamResponseText, textRequest.Model)
|
||||||
} else {
|
} else {
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
promptTokens = textResponse.Usage.PromptTokens
|
||||||
completionTokens = textResponse.Usage.CompletionTokens
|
completionTokens = textResponse.Usage.CompletionTokens
|
||||||
|
if apiType == APITypeZhipu {
|
||||||
|
// zhipu's API does not return prompt tokens & completion tokens
|
||||||
|
promptTokens = textResponse.Usage.TotalTokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||||
quota = int(float64(quota) * ratio)
|
quota = int(float64(quota) * ratio)
|
||||||
@@ -266,7 +330,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
textResponse.Usage = *usage
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case APITypeClaude:
|
case APITypeClaude:
|
||||||
@@ -282,7 +348,67 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
textResponse.Usage = *usage
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case APITypeBaidu:
|
||||||
|
if isStream {
|
||||||
|
err, usage := baiduStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := baiduHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case APITypePaLM:
|
||||||
|
if textRequest.Stream { // PaLM2 API does not support stream
|
||||||
|
err, responseText := palmStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
streamResponseText = responseText
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case APITypeZhipu:
|
||||||
|
if isStream {
|
||||||
|
err, usage := zhipuStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := zhipuHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|||||||
290
controller/relay-zhipu.go
Normal file
290
controller/relay-zhipu.go
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||||
|
// chatglm_std, chatglm_lite
|
||||||
|
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||||
|
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||||
|
|
||||||
|
type ZhipuMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuRequest struct {
|
||||||
|
Prompt []ZhipuMessage `json:"prompt"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
RequestId string `json:"request_id,omitempty"`
|
||||||
|
Incremental bool `json:"incremental,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuResponseData struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
TaskStatus string `json:"task_status"`
|
||||||
|
Choices []ZhipuMessage `json:"choices"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Msg string `json:"msg"`
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data ZhipuResponseData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ZhipuStreamMetaResponse struct {
|
||||||
|
RequestId string `json:"request_id"`
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
TaskStatus string `json:"task_status"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type zhipuTokenData struct {
|
||||||
|
Token string
|
||||||
|
ExpiryTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var zhipuTokens sync.Map
|
||||||
|
var expSeconds int64 = 24 * 3600
|
||||||
|
|
||||||
|
func getZhipuToken(apikey string) string {
|
||||||
|
data, ok := zhipuTokens.Load(apikey)
|
||||||
|
if ok {
|
||||||
|
tokenData := data.(zhipuTokenData)
|
||||||
|
if time.Now().Before(tokenData.ExpiryTime) {
|
||||||
|
return tokenData.Token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
split := strings.Split(apikey, ".")
|
||||||
|
if len(split) != 2 {
|
||||||
|
common.SysError("invalid zhipu key: " + apikey)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
id := split[0]
|
||||||
|
secret := split[1]
|
||||||
|
|
||||||
|
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||||
|
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||||
|
|
||||||
|
timestamp := time.Now().UnixNano() / 1e6
|
||||||
|
|
||||||
|
payload := jwt.MapClaims{
|
||||||
|
"api_key": id,
|
||||||
|
"exp": expMillis,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||||
|
|
||||||
|
token.Header["alg"] = "HS256"
|
||||||
|
token.Header["sign_type"] = "SIGN"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString([]byte(secret))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
zhipuTokens.Store(apikey, zhipuTokenData{
|
||||||
|
Token: tokenString,
|
||||||
|
ExpiryTime: expiryTime,
|
||||||
|
})
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||||
|
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||||
|
for _, message := range request.Messages {
|
||||||
|
messages = append(messages, ZhipuMessage{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &ZhipuRequest{
|
||||||
|
Prompt: messages,
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
Incremental: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Id: response.Data.TaskId,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
||||||
|
Usage: response.Data.Usage,
|
||||||
|
}
|
||||||
|
for i, choice := range response.Data.Choices {
|
||||||
|
openaiChoice := OpenAITextResponseChoice{
|
||||||
|
Index: i,
|
||||||
|
Message: Message{
|
||||||
|
Role: choice.Role,
|
||||||
|
Content: strings.Trim(choice.Content, "\""),
|
||||||
|
},
|
||||||
|
FinishReason: "",
|
||||||
|
}
|
||||||
|
if i == len(response.Data.Choices)-1 {
|
||||||
|
openaiChoice.FinishReason = "stop"
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = zhipuResponse
|
||||||
|
choice.FinishReason = ""
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "chatglm",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = ""
|
||||||
|
choice.FinishReason = "stop"
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Id: zhipuResponse.RequestId,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "chatglm",
|
||||||
|
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||||
|
}
|
||||||
|
return &response, &zhipuResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var usage *Usage
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
metaChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
data = strings.Trim(data, "\"")
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] == "data:" {
|
||||||
|
dataChan <- data[5:]
|
||||||
|
} else if data[:5] == "meta:" {
|
||||||
|
metaChan <- data[5:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
response := streamResponseZhipu2OpenAI(data)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case data := <-metaChan:
|
||||||
|
var zhipuResponse ZhipuStreamMetaResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
usage = zhipuUsage
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var zhipuResponse ZhipuResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if !zhipuResponse.Success {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: zhipuResponse.Msg,
|
||||||
|
Type: "zhipu_error",
|
||||||
|
Param: "",
|
||||||
|
Code: zhipuResponse.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
@@ -3,12 +3,13 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
@@ -477,6 +478,16 @@ func DeleteUser(c *gin.Context) {
|
|||||||
|
|
||||||
func DeleteSelf(c *gin.Context) {
|
func DeleteSelf(c *gin.Context) {
|
||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
|
user, _ := model.GetUserById(id, false)
|
||||||
|
|
||||||
|
if user.Role == common.RoleRootUser {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "不能删除超级管理员账户",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err := model.DeleteUserById(id)
|
err := model.DeleteUserById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -11,6 +11,7 @@ require (
|
|||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/go-playground/validator/v10 v10.14.0
|
github.com/go-playground/validator/v10 v10.14.0
|
||||||
github.com/go-redis/redis/v8 v8.11.5
|
github.com/go-redis/redis/v8 v8.11.5
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/pkoukk/tiktoken-go v0.1.1
|
github.com/pkoukk/tiktoken-go v0.1.1
|
||||||
golang.org/x/crypto v0.9.0
|
golang.org/x/crypto v0.9.0
|
||||||
@@ -20,7 +21,6 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff // indirect
|
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
@@ -32,7 +32,6 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-sql-driver/mysql v1.6.0 // indirect
|
github.com/go-sql-driver/mysql v1.6.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible // indirect
|
|
||||||
github.com/gorilla/context v1.1.1 // indirect
|
github.com/gorilla/context v1.1.1 // indirect
|
||||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||||
github.com/gorilla/sessions v1.2.1 // indirect
|
github.com/gorilla/sessions v1.2.1 // indirect
|
||||||
|
|||||||
7
go.sum
7
go.sum
@@ -1,5 +1,3 @@
|
|||||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff h1:RmdPFa+slIr4SCBg4st/l/vZWVe9QJKMXGO60Bxbe04=
|
|
||||||
github.com/boj/redistore v0.0.0-20180917114910-cd5dcc76aeff/go.mod h1:+RTT1BOk5P97fT2CiHkbFQwkK3mjsFAP6zCYV2aXtjw=
|
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
@@ -54,10 +52,10 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
|
|||||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
|
|
||||||
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
|
|
||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
@@ -67,7 +65,6 @@ github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8
|
|||||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||||
github.com/gorilla/sessions v1.1.1/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
|
||||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
||||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
|||||||
1
main.go
1
main.go
@@ -54,6 +54,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
|
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
|
||||||
}
|
}
|
||||||
|
common.SyncFrequency = frequency
|
||||||
go model.SyncOptions(frequency)
|
go model.SyncOptions(frequency)
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
go model.SyncChannelCache(frequency)
|
go model.SyncChannelCache(frequency)
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := "无可用渠道"
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var (
|
||||||
TokenCacheSeconds = 60 * 60
|
TokenCacheSeconds = common.SyncFrequency
|
||||||
UserId2GroupCacheSeconds = 60 * 60
|
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||||
UserId2QuotaCacheSeconds = 10 * 60
|
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||||
UserId2StatusCacheSeconds = 60 * 60
|
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||||
)
|
)
|
||||||
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||||
@@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
|
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set token error: " + err.Error())
|
common.SysError("Redis set token error: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user group error: " + err.Error())
|
common.SysError("Redis set user group error: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user quota error: " + err.Error())
|
common.SysError("Redis set user quota error: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool {
|
|||||||
status = common.UserStatusEnabled
|
status = common.UserStatusEnabled
|
||||||
}
|
}
|
||||||
enabled = fmt.Sprintf("%d", status)
|
enabled = fmt.Sprintf("%d", status)
|
||||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("Redis set user enabled error: " + err.Error())
|
common.SysError("Redis set user enabled error: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
redemption := &Redemption{}
|
redemption := &Redemption{}
|
||||||
|
|
||||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||||
err := DB.Where("`key` = ?", key).First(redemption).Error
|
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("无效的兑换码")
|
return errors.New("无效的兑换码")
|
||||||
}
|
}
|
||||||
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
||||||
return errors.New("该兑换码已被使用")
|
return errors.New("该兑换码已被使用")
|
||||||
}
|
}
|
||||||
err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
redemption.RedeemedTime = common.GetTimestamp()
|
redemption.RedeemedTime = common.GetTimestamp()
|
||||||
redemption.Status = common.RedemptionCodeStatusUsed
|
redemption.Status = common.RedemptionCodeStatusUsed
|
||||||
return redemption.SelectUpdate()
|
err = tx.Save(redemption).Error
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.New("兑换失败," + err.Error())
|
return 0, errors.New("兑换失败," + err.Error())
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
{
|
{
|
||||||
selfRoute.GET("/self", controller.GetSelf)
|
selfRoute.GET("/self", controller.GetSelf)
|
||||||
selfRoute.PUT("/self", controller.UpdateSelf)
|
selfRoute.PUT("/self", controller.UpdateSelf)
|
||||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
selfRoute.DELETE("/self", middleware.TurnstileCheck(), controller.DeleteSelf)
|
||||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||||
selfRoute.GET("/aff", controller.GetAffCode)
|
selfRoute.GET("/aff", controller.GetAffCode)
|
||||||
selfRoute.POST("/topup", controller.TopUp)
|
selfRoute.POST("/topup", controller.TopUp)
|
||||||
|
|||||||
@@ -1,36 +1,25 @@
|
|||||||
import React, { useContext, useEffect, useState } from 'react';
|
import React, { useContext, useEffect, useState } from 'react';
|
||||||
import {
|
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
|
||||||
Button,
|
|
||||||
Divider,
|
|
||||||
Form,
|
|
||||||
Grid,
|
|
||||||
Header,
|
|
||||||
Image,
|
|
||||||
Message,
|
|
||||||
Modal,
|
|
||||||
Segment,
|
|
||||||
} from 'semantic-ui-react';
|
|
||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showSuccess, showInfo } from '../helpers';
|
import { API, getLogo, showError, showSuccess } from '../helpers';
|
||||||
|
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
username: '',
|
username: '',
|
||||||
password: '',
|
password: '',
|
||||||
wechat_verification_code: '',
|
wechat_verification_code: ''
|
||||||
});
|
});
|
||||||
const [searchParams, setSearchParams] = useSearchParams();
|
const [searchParams, setSearchParams] = useSearchParams();
|
||||||
const [submitted, setSubmitted] = useState(false);
|
const [submitted, setSubmitted] = useState(false);
|
||||||
const { username, password } = inputs;
|
const { username, password } = inputs;
|
||||||
const [userState, userDispatch] = useContext(UserContext);
|
const [userState, userDispatch] = useContext(UserContext);
|
||||||
let navigate = useNavigate();
|
let navigate = useNavigate();
|
||||||
|
|
||||||
const [status, setStatus] = useState({});
|
const [status, setStatus] = useState({});
|
||||||
const logo = getLogo();
|
const logo = getLogo();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (searchParams.get("expired")) {
|
if (searchParams.get('expired')) {
|
||||||
showError('未登录或登录已过期,请重新登录!');
|
showError('未登录或登录已过期,请重新登录!');
|
||||||
}
|
}
|
||||||
let status = localStorage.getItem('status');
|
let status = localStorage.getItem('status');
|
||||||
@@ -78,7 +67,7 @@ const LoginForm = () => {
|
|||||||
if (username && password) {
|
if (username && password) {
|
||||||
const res = await API.post(`/api/user/login`, {
|
const res = await API.post(`/api/user/login`, {
|
||||||
username,
|
username,
|
||||||
password,
|
password
|
||||||
});
|
});
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -93,44 +82,44 @@ const LoginForm = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid textAlign="center" style={{ marginTop: '48px' }}>
|
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
||||||
<Grid.Column style={{ maxWidth: 450 }}>
|
<Grid.Column style={{ maxWidth: 450 }}>
|
||||||
<Header as="h2" color="" textAlign="center">
|
<Header as='h2' color='' textAlign='center'>
|
||||||
<Image src={logo} /> 用户登录
|
<Image src={logo} /> 用户登录
|
||||||
</Header>
|
</Header>
|
||||||
<Form size="large">
|
<Form size='large'>
|
||||||
<Segment>
|
<Segment>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
fluid
|
fluid
|
||||||
icon="user"
|
icon='user'
|
||||||
iconPosition="left"
|
iconPosition='left'
|
||||||
placeholder="用户名"
|
placeholder='用户名'
|
||||||
name="username"
|
name='username'
|
||||||
value={username}
|
value={username}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
/>
|
/>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
fluid
|
fluid
|
||||||
icon="lock"
|
icon='lock'
|
||||||
iconPosition="left"
|
iconPosition='left'
|
||||||
placeholder="密码"
|
placeholder='密码'
|
||||||
name="password"
|
name='password'
|
||||||
type="password"
|
type='password'
|
||||||
value={password}
|
value={password}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
/>
|
/>
|
||||||
<Button color="" fluid size="large" onClick={handleSubmit}>
|
<Button color='green' fluid size='large' onClick={handleSubmit}>
|
||||||
登录
|
登录
|
||||||
</Button>
|
</Button>
|
||||||
</Segment>
|
</Segment>
|
||||||
</Form>
|
</Form>
|
||||||
<Message>
|
<Message>
|
||||||
忘记密码?
|
忘记密码?
|
||||||
<Link to="/reset" className="btn btn-link">
|
<Link to='/reset' className='btn btn-link'>
|
||||||
点击重置
|
点击重置
|
||||||
</Link>
|
</Link>
|
||||||
; 没有账户?
|
; 没有账户?
|
||||||
<Link to="/register" className="btn btn-link">
|
<Link to='/register' className='btn btn-link'>
|
||||||
点击注册
|
点击注册
|
||||||
</Link>
|
</Link>
|
||||||
</Message>
|
</Message>
|
||||||
@@ -140,8 +129,8 @@ const LoginForm = () => {
|
|||||||
{status.github_oauth ? (
|
{status.github_oauth ? (
|
||||||
<Button
|
<Button
|
||||||
circular
|
circular
|
||||||
color="black"
|
color='black'
|
||||||
icon="github"
|
icon='github'
|
||||||
onClick={onGitHubOAuthClicked}
|
onClick={onGitHubOAuthClicked}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
@@ -150,8 +139,8 @@ const LoginForm = () => {
|
|||||||
{status.wechat_login ? (
|
{status.wechat_login ? (
|
||||||
<Button
|
<Button
|
||||||
circular
|
circular
|
||||||
color="green"
|
color='green'
|
||||||
icon="wechat"
|
icon='wechat'
|
||||||
onClick={onWeChatLoginClicked}
|
onClick={onWeChatLoginClicked}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
@@ -175,18 +164,18 @@ const LoginForm = () => {
|
|||||||
微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)
|
微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<Form size="large">
|
<Form size='large'>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
fluid
|
fluid
|
||||||
placeholder="验证码"
|
placeholder='验证码'
|
||||||
name="wechat_verification_code"
|
name='wechat_verification_code'
|
||||||
value={inputs.wechat_verification_code}
|
value={inputs.wechat_verification_code}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
/>
|
/>
|
||||||
<Button
|
<Button
|
||||||
color=""
|
color=''
|
||||||
fluid
|
fluid
|
||||||
size="large"
|
size='large'
|
||||||
onClick={onSubmitWeChatVerificationCode}
|
onClick={onSubmitWeChatVerificationCode}
|
||||||
>
|
>
|
||||||
登录
|
登录
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ const PasswordResetConfirm = () => {
|
|||||||
|
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
|
const [disableButton, setDisableButton] = useState(false);
|
||||||
|
const [countdown, setCountdown] = useState(30);
|
||||||
|
|
||||||
|
const [newPassword, setNewPassword] = useState('');
|
||||||
|
|
||||||
const [searchParams, setSearchParams] = useSearchParams();
|
const [searchParams, setSearchParams] = useSearchParams();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let token = searchParams.get('token');
|
let token = searchParams.get('token');
|
||||||
@@ -22,7 +27,21 @@ const PasswordResetConfirm = () => {
|
|||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let countdownInterval = null;
|
||||||
|
if (disableButton && countdown > 0) {
|
||||||
|
countdownInterval = setInterval(() => {
|
||||||
|
setCountdown(countdown - 1);
|
||||||
|
}, 1000);
|
||||||
|
} else if (countdown === 0) {
|
||||||
|
setDisableButton(false);
|
||||||
|
setCountdown(30);
|
||||||
|
}
|
||||||
|
return () => clearInterval(countdownInterval);
|
||||||
|
}, [disableButton, countdown]);
|
||||||
|
|
||||||
async function handleSubmit(e) {
|
async function handleSubmit(e) {
|
||||||
|
setDisableButton(true);
|
||||||
if (!email) return;
|
if (!email) return;
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
const res = await API.post(`/api/user/reset`, {
|
const res = await API.post(`/api/user/reset`, {
|
||||||
@@ -32,14 +51,15 @@ const PasswordResetConfirm = () => {
|
|||||||
const { success, message } = res.data;
|
const { success, message } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
let password = res.data.data;
|
let password = res.data.data;
|
||||||
|
setNewPassword(password);
|
||||||
await copy(password);
|
await copy(password);
|
||||||
showNotice(`密码已重置并已复制到剪贴板:${password}`);
|
showNotice(`新密码已复制到剪贴板:${password}`);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
<Grid textAlign='center' style={{ marginTop: '48px' }}>
|
||||||
<Grid.Column style={{ maxWidth: 450 }}>
|
<Grid.Column style={{ maxWidth: 450 }}>
|
||||||
@@ -57,20 +77,37 @@ const PasswordResetConfirm = () => {
|
|||||||
value={email}
|
value={email}
|
||||||
readOnly
|
readOnly
|
||||||
/>
|
/>
|
||||||
|
{newPassword && (
|
||||||
|
<Form.Input
|
||||||
|
fluid
|
||||||
|
icon='lock'
|
||||||
|
iconPosition='left'
|
||||||
|
placeholder='新密码'
|
||||||
|
name='newPassword'
|
||||||
|
value={newPassword}
|
||||||
|
readOnly
|
||||||
|
onClick={(e) => {
|
||||||
|
e.target.select();
|
||||||
|
navigator.clipboard.writeText(newPassword);
|
||||||
|
showNotice(`密码已复制到剪贴板:${newPassword}`);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
<Button
|
<Button
|
||||||
color=''
|
color='green'
|
||||||
fluid
|
fluid
|
||||||
size='large'
|
size='large'
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
loading={loading}
|
loading={loading}
|
||||||
|
disabled={disableButton}
|
||||||
>
|
>
|
||||||
提交
|
{disableButton ? `密码重置完成` : '提交'}
|
||||||
</Button>
|
</Button>
|
||||||
</Segment>
|
</Segment>
|
||||||
</Form>
|
</Form>
|
||||||
</Grid.Column>
|
</Grid.Column>
|
||||||
</Grid>
|
</Grid>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default PasswordResetConfirm;
|
export default PasswordResetConfirm;
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import Turnstile from 'react-turnstile';
|
|||||||
|
|
||||||
const PasswordResetForm = () => {
|
const PasswordResetForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
email: '',
|
email: ''
|
||||||
});
|
});
|
||||||
const { email } = inputs;
|
const { email } = inputs;
|
||||||
|
|
||||||
@@ -13,24 +13,29 @@ const PasswordResetForm = () => {
|
|||||||
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
||||||
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
||||||
const [turnstileToken, setTurnstileToken] = useState('');
|
const [turnstileToken, setTurnstileToken] = useState('');
|
||||||
|
const [disableButton, setDisableButton] = useState(false);
|
||||||
|
const [countdown, setCountdown] = useState(30);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let status = localStorage.getItem('status');
|
let countdownInterval = null;
|
||||||
if (status) {
|
if (disableButton && countdown > 0) {
|
||||||
status = JSON.parse(status);
|
countdownInterval = setInterval(() => {
|
||||||
if (status.turnstile_check) {
|
setCountdown(countdown - 1);
|
||||||
setTurnstileEnabled(true);
|
}, 1000);
|
||||||
setTurnstileSiteKey(status.turnstile_site_key);
|
} else if (countdown === 0) {
|
||||||
}
|
setDisableButton(false);
|
||||||
|
setCountdown(30);
|
||||||
}
|
}
|
||||||
}, []);
|
return () => clearInterval(countdownInterval);
|
||||||
|
}, [disableButton, countdown]);
|
||||||
|
|
||||||
function handleChange(e) {
|
function handleChange(e) {
|
||||||
const { name, value } = e.target;
|
const { name, value } = e.target;
|
||||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
setInputs(inputs => ({ ...inputs, [name]: value }));
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleSubmit(e) {
|
async function handleSubmit(e) {
|
||||||
|
setDisableButton(true);
|
||||||
if (!email) return;
|
if (!email) return;
|
||||||
if (turnstileEnabled && turnstileToken === '') {
|
if (turnstileEnabled && turnstileToken === '') {
|
||||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||||
@@ -78,13 +83,14 @@ const PasswordResetForm = () => {
|
|||||||
<></>
|
<></>
|
||||||
)}
|
)}
|
||||||
<Button
|
<Button
|
||||||
color=''
|
color='green'
|
||||||
fluid
|
fluid
|
||||||
size='large'
|
size='large'
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
loading={loading}
|
loading={loading}
|
||||||
|
disabled={disableButton}
|
||||||
>
|
>
|
||||||
提交
|
{disableButton ? `重试 (${countdown})` : '提交'}
|
||||||
</Button>
|
</Button>
|
||||||
</Segment>
|
</Segment>
|
||||||
</Form>
|
</Form>
|
||||||
|
|||||||
@@ -1,22 +1,30 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useContext, useEffect, useState } from 'react';
|
||||||
import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
|
import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link, useNavigate } from 'react-router-dom';
|
||||||
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
|
||||||
import Turnstile from 'react-turnstile';
|
import Turnstile from 'react-turnstile';
|
||||||
|
import { UserContext } from '../context/User';
|
||||||
|
|
||||||
const PersonalSetting = () => {
|
const PersonalSetting = () => {
|
||||||
|
const [userState, userDispatch] = useContext(UserContext);
|
||||||
|
let navigate = useNavigate();
|
||||||
|
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
wechat_verification_code: '',
|
wechat_verification_code: '',
|
||||||
email_verification_code: '',
|
email_verification_code: '',
|
||||||
email: '',
|
email: '',
|
||||||
|
self_account_deletion_confirmation: ''
|
||||||
});
|
});
|
||||||
const [status, setStatus] = useState({});
|
const [status, setStatus] = useState({});
|
||||||
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
|
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
|
||||||
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
|
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
|
||||||
|
const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false);
|
||||||
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
|
||||||
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
|
||||||
const [turnstileToken, setTurnstileToken] = useState('');
|
const [turnstileToken, setTurnstileToken] = useState('');
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [disableButton, setDisableButton] = useState(false);
|
||||||
|
const [countdown, setCountdown] = useState(30);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let status = localStorage.getItem('status');
|
let status = localStorage.getItem('status');
|
||||||
@@ -30,6 +38,19 @@ const PersonalSetting = () => {
|
|||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let countdownInterval = null;
|
||||||
|
if (disableButton && countdown > 0) {
|
||||||
|
countdownInterval = setInterval(() => {
|
||||||
|
setCountdown(countdown - 1);
|
||||||
|
}, 1000);
|
||||||
|
} else if (countdown === 0) {
|
||||||
|
setDisableButton(false);
|
||||||
|
setCountdown(30);
|
||||||
|
}
|
||||||
|
return () => clearInterval(countdownInterval); // Clean up on unmount
|
||||||
|
}, [disableButton, countdown]);
|
||||||
|
|
||||||
const handleInputChange = (e, { name, value }) => {
|
const handleInputChange = (e, { name, value }) => {
|
||||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||||
};
|
};
|
||||||
@@ -57,6 +78,26 @@ const PersonalSetting = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteAccount = async () => {
|
||||||
|
if (inputs.self_account_deletion_confirmation !== userState.user.username) {
|
||||||
|
showError('请输入你的账户名以确认删除!');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await API.delete('/api/user/self');
|
||||||
|
const { success, message } = res.data;
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
showSuccess('账户已删除!');
|
||||||
|
await API.get('/api/user/logout');
|
||||||
|
userDispatch({ type: 'logout' });
|
||||||
|
localStorage.removeItem('user');
|
||||||
|
navigate('/login');
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const bindWeChat = async () => {
|
const bindWeChat = async () => {
|
||||||
if (inputs.wechat_verification_code === '') return;
|
if (inputs.wechat_verification_code === '') return;
|
||||||
const res = await API.get(
|
const res = await API.get(
|
||||||
@@ -78,6 +119,7 @@ const PersonalSetting = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const sendVerificationCode = async () => {
|
const sendVerificationCode = async () => {
|
||||||
|
setDisableButton(true);
|
||||||
if (inputs.email === '') return;
|
if (inputs.email === '') return;
|
||||||
if (turnstileEnabled && turnstileToken === '') {
|
if (turnstileEnabled && turnstileToken === '') {
|
||||||
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!');
|
||||||
@@ -123,6 +165,9 @@ const PersonalSetting = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
|
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
|
||||||
<Button onClick={getAffLink}>复制邀请链接</Button>
|
<Button onClick={getAffLink}>复制邀请链接</Button>
|
||||||
|
<Button onClick={() => {
|
||||||
|
setShowAccountDeleteModal(true);
|
||||||
|
}}>删除个人账户</Button>
|
||||||
<Divider />
|
<Divider />
|
||||||
<Header as='h3'>账号绑定</Header>
|
<Header as='h3'>账号绑定</Header>
|
||||||
{
|
{
|
||||||
@@ -195,8 +240,8 @@ const PersonalSetting = () => {
|
|||||||
name='email'
|
name='email'
|
||||||
type='email'
|
type='email'
|
||||||
action={
|
action={
|
||||||
<Button onClick={sendVerificationCode} disabled={loading}>
|
<Button onClick={sendVerificationCode} disabled={disableButton || loading}>
|
||||||
获取验证码
|
{disableButton ? `重新发送(${countdown})` : '获取验证码'}
|
||||||
</Button>
|
</Button>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
@@ -230,6 +275,47 @@ const PersonalSetting = () => {
|
|||||||
</Modal.Description>
|
</Modal.Description>
|
||||||
</Modal.Content>
|
</Modal.Content>
|
||||||
</Modal>
|
</Modal>
|
||||||
|
<Modal
|
||||||
|
onClose={() => setShowAccountDeleteModal(false)}
|
||||||
|
onOpen={() => setShowAccountDeleteModal(true)}
|
||||||
|
open={showAccountDeleteModal}
|
||||||
|
size={'tiny'}
|
||||||
|
style={{ maxWidth: '450px' }}
|
||||||
|
>
|
||||||
|
<Modal.Header>确认删除自己的帐户</Modal.Header>
|
||||||
|
<Modal.Content>
|
||||||
|
<Modal.Description>
|
||||||
|
<Form size='large'>
|
||||||
|
<Form.Input
|
||||||
|
fluid
|
||||||
|
placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`}
|
||||||
|
name='self_account_deletion_confirmation'
|
||||||
|
value={inputs.self_account_deletion_confirmation}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
/>
|
||||||
|
{turnstileEnabled ? (
|
||||||
|
<Turnstile
|
||||||
|
sitekey={turnstileSiteKey}
|
||||||
|
onVerify={(token) => {
|
||||||
|
setTurnstileToken(token);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<></>
|
||||||
|
)}
|
||||||
|
<Button
|
||||||
|
color='red'
|
||||||
|
fluid
|
||||||
|
size='large'
|
||||||
|
onClick={deleteAccount}
|
||||||
|
loading={loading}
|
||||||
|
>
|
||||||
|
删除
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
</Modal.Description>
|
||||||
|
</Modal.Content>
|
||||||
|
</Modal>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,13 +1,5 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import {
|
import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react';
|
||||||
Button,
|
|
||||||
Form,
|
|
||||||
Grid,
|
|
||||||
Header,
|
|
||||||
Image,
|
|
||||||
Message,
|
|
||||||
Segment,
|
|
||||||
} from 'semantic-ui-react';
|
|
||||||
import { Link, useNavigate } from 'react-router-dom';
|
import { Link, useNavigate } from 'react-router-dom';
|
||||||
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
||||||
import Turnstile from 'react-turnstile';
|
import Turnstile from 'react-turnstile';
|
||||||
@@ -18,7 +10,7 @@ const RegisterForm = () => {
|
|||||||
password: '',
|
password: '',
|
||||||
password2: '',
|
password2: '',
|
||||||
email: '',
|
email: '',
|
||||||
verification_code: '',
|
verification_code: ''
|
||||||
});
|
});
|
||||||
const { username, password, password2 } = inputs;
|
const { username, password, password2 } = inputs;
|
||||||
const [showEmailVerification, setShowEmailVerification] = useState(false);
|
const [showEmailVerification, setShowEmailVerification] = useState(false);
|
||||||
@@ -178,7 +170,7 @@ const RegisterForm = () => {
|
|||||||
<></>
|
<></>
|
||||||
)}
|
)}
|
||||||
<Button
|
<Button
|
||||||
color=''
|
color='green'
|
||||||
fluid
|
fluid
|
||||||
size='large'
|
size='large'
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
export const CHANNEL_OPTIONS = [
|
export const CHANNEL_OPTIONS = [
|
||||||
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
||||||
{ key: 14, text: 'Anthropic', value: 14, color: 'black' },
|
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
|
||||||
{ key: 8, text: '自定义', value: 8, color: 'pink' },
|
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
|
||||||
{ key: 3, text: 'Azure', value: 3, color: 'olive' },
|
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
|
||||||
{ key: 2, text: 'API2D', value: 2, color: 'blue' },
|
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
|
||||||
{ key: 4, text: 'CloseAI', value: 4, color: 'teal' },
|
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||||
{ key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 6, text: 'OpenAI Max', value: 6, color: 'violet' },
|
{ key: 2, text: '代理:API2D', value: 2, color: 'blue' },
|
||||||
{ key: 7, text: 'OhMyGPT', value: 7, color: 'purple' },
|
{ key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
|
||||||
{ key: 9, text: 'AI.LS', value: 9, color: 'yellow' },
|
{ key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
|
||||||
{ key: 10, text: 'AI Proxy', value: 10, color: 'purple' },
|
{ key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' },
|
||||||
{ key: 12, text: 'API2GPT', value: 12, color: 'blue' },
|
{ key: 4, text: '代理:CloseAI', value: 4, color: 'teal' },
|
||||||
{ key: 13, text: 'AIGC2D', value: 13, color: 'purple' }
|
{ key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
|
||||||
|
{ key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
|
||||||
|
{ key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
|
||||||
|
{ key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
|
||||||
];
|
];
|
||||||
@@ -46,9 +46,7 @@ const About = () => {
|
|||||||
about.startsWith('https://') ? <iframe
|
about.startsWith('https://') ? <iframe
|
||||||
src={about}
|
src={about}
|
||||||
style={{ width: '100%', height: '100vh', border: 'none' }}
|
style={{ width: '100%', height: '100vh', border: 'none' }}
|
||||||
/> : <Segment>
|
/> : <div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div>
|
||||||
<div style={{ fontSize: 'larger' }} dangerouslySetInnerHTML={{ __html: about }}></div>
|
|
||||||
</Segment>
|
|
||||||
}
|
}
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -215,26 +215,12 @@ const EditChannel = () => {
|
|||||||
</Form.Field>
|
</Form.Field>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
|
||||||
inputs.type !== 3 && inputs.type !== 8 && (
|
|
||||||
<Form.Field>
|
|
||||||
<Form.Input
|
|
||||||
label='镜像'
|
|
||||||
name='base_url'
|
|
||||||
placeholder={'此项可选,输入镜像站地址,格式为:https://domain.com'}
|
|
||||||
onChange={handleInputChange}
|
|
||||||
value={inputs.base_url}
|
|
||||||
autoComplete='new-password'
|
|
||||||
/>
|
|
||||||
</Form.Field>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='名称'
|
label='名称'
|
||||||
required
|
required
|
||||||
name='name'
|
name='name'
|
||||||
placeholder={'请输入名称'}
|
placeholder={'请为渠道命名'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.name}
|
value={inputs.name}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@@ -243,7 +229,7 @@ const EditChannel = () => {
|
|||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='分组'
|
label='分组'
|
||||||
placeholder={'请选择分组'}
|
placeholder={'请选择可以使用该渠道的分组'}
|
||||||
name='groups'
|
name='groups'
|
||||||
required
|
required
|
||||||
fluid
|
fluid
|
||||||
@@ -260,7 +246,7 @@ const EditChannel = () => {
|
|||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='模型'
|
label='模型'
|
||||||
placeholder={'请选择该通道所支持的模型'}
|
placeholder={'请选择该渠道所支持的模型'}
|
||||||
name='models'
|
name='models'
|
||||||
required
|
required
|
||||||
fluid
|
fluid
|
||||||
@@ -312,7 +298,7 @@ const EditChannel = () => {
|
|||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
label='模型映射'
|
label='模型映射'
|
||||||
placeholder={`此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
|
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
|
||||||
name='model_mapping'
|
name='model_mapping'
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.model_mapping}
|
value={inputs.model_mapping}
|
||||||
@@ -337,7 +323,7 @@ const EditChannel = () => {
|
|||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
required
|
required
|
||||||
placeholder={'请输入密钥'}
|
placeholder={inputs.type === 15 ? "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次" : '请输入渠道对应的鉴权密钥'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
@@ -354,6 +340,20 @@ const EditChannel = () => {
|
|||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
inputs.type !== 3 && inputs.type !== 8 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='镜像'
|
||||||
|
name='base_url'
|
||||||
|
placeholder={'此项可选,用于通过镜像站来进行 API 调用,请输入镜像站地址,格式为:https://domain.com'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.base_url}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
|
<Button type={isEdit ? "button" : "submit"} positive onClick={submit}>提交</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Segment>
|
</Segment>
|
||||||
|
|||||||
@@ -7,24 +7,32 @@ const TopUp = () => {
|
|||||||
const [redemptionCode, setRedemptionCode] = useState('');
|
const [redemptionCode, setRedemptionCode] = useState('');
|
||||||
const [topUpLink, setTopUpLink] = useState('');
|
const [topUpLink, setTopUpLink] = useState('');
|
||||||
const [userQuota, setUserQuota] = useState(0);
|
const [userQuota, setUserQuota] = useState(0);
|
||||||
|
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||||
|
|
||||||
const topUp = async () => {
|
const topUp = async () => {
|
||||||
if (redemptionCode === '') {
|
if (redemptionCode === '') {
|
||||||
showInfo('请输入充值码!')
|
showInfo('请输入充值码!')
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const res = await API.post('/api/user/topup', {
|
setIsSubmitting(true);
|
||||||
key: redemptionCode
|
try {
|
||||||
});
|
const res = await API.post('/api/user/topup', {
|
||||||
const { success, message, data } = res.data;
|
key: redemptionCode
|
||||||
if (success) {
|
|
||||||
showSuccess('充值成功!');
|
|
||||||
setUserQuota((quota) => {
|
|
||||||
return quota + data;
|
|
||||||
});
|
});
|
||||||
setRedemptionCode('');
|
const { success, message, data } = res.data;
|
||||||
} else {
|
if (success) {
|
||||||
showError(message);
|
showSuccess('充值成功!');
|
||||||
|
setUserQuota((quota) => {
|
||||||
|
return quota + data;
|
||||||
|
});
|
||||||
|
setRedemptionCode('');
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
showError('请求失败');
|
||||||
|
} finally {
|
||||||
|
setIsSubmitting(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -74,8 +82,8 @@ const TopUp = () => {
|
|||||||
<Button color='green' onClick={openTopUpLink}>
|
<Button color='green' onClick={openTopUpLink}>
|
||||||
获取兑换码
|
获取兑换码
|
||||||
</Button>
|
</Button>
|
||||||
<Button color='yellow' onClick={topUp}>
|
<Button color='yellow' onClick={topUp} disabled={isSubmitting}>
|
||||||
充值
|
{isSubmitting ? '兑换中...' : '兑换'}
|
||||||
</Button>
|
</Button>
|
||||||
</Form>
|
</Form>
|
||||||
</Grid.Column>
|
</Grid.Column>
|
||||||
@@ -92,5 +100,4 @@ const TopUp = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export default TopUp;
|
||||||
export default TopUp;
|
|
||||||
Reference in New Issue
Block a user