Compare commits

...

24 Commits

Author SHA1 Message Date
JustSong
cbd62011b8 chore: add database migration prompt 2023-10-02 12:13:30 +08:00
JustSong
4701897e2e chore: sync database indices with pro version 2023-10-02 12:11:02 +08:00
JustSong
0f6c132a80 docs: update readme 2023-10-01 13:26:42 +08:00
JustSong
3cac45dc85 docs: update readme 2023-10-01 13:16:25 +08:00
JustSong
47c08c72ce fix: check user quota when pre-consume quota 2023-10-01 12:49:40 +08:00
JustSong
53b2cace0b chore: sync model schema 2023-09-29 18:13:57 +08:00
JustSong
f0fc991b44 chore: remind user to change default password (close #516) 2023-09-29 18:07:20 +08:00
JustSong
594f06e7b0 perf: lazy initialization for token encoders (close #566) 2023-09-29 17:56:11 +08:00
JustSong
197d1d7a9d docs: update readme 2023-09-29 17:49:47 +08:00
JustSong
f9b748c2ca chore: add MEMORY_CACHE_ENABLED env variable 2023-09-29 11:38:27 +08:00
JustSong
fd98463611 chore: update ali's model name 2023-09-23 22:57:59 +08:00
JustSong
f5a1cd3463 feat: add support for gpt-3.5-turbo-instruct (close #545) 2023-09-23 22:37:11 +08:00
igophper
8651451e53 fix: sum null to 0 (#541)
Co-authored-by: igophper <admin@jialilgu.cn>
2023-09-19 22:39:54 +08:00
JustSong
1c5bb97a42 fix: fix gorm expression
Co-authored-by: 初音控灬 <xyfacai@gmail.com>
2023-09-18 23:11:37 +08:00
JustSong
de868e4e4e fix: fix gorm expression
Co-authored-by: 初音控灬 <xyfacai@gmail.com>
2023-09-18 23:07:59 +08:00
JustSong
1d258cc898 fix: add default value for base url 2023-09-18 22:49:05 +08:00
JustSong
37e09d764c fix: fix random selection is not working when directly using database 2023-09-18 22:39:10 +08:00
JustSong
159b9e3369 fix: fix unable to set zero value for base url & model mapping 2023-09-18 22:07:17 +08:00
JustSong
92001986db Merge branch 'main' of https://github.com/songquanpeng/one-api 2023-09-18 21:44:36 +08:00
JustSong
a5647b1ea7 fix: fix priority not updated & random choice not working 2023-09-18 21:43:45 +08:00
Redreamality
215e54fc96 docs: update readme (#502)
* Update README.md

* docs: update README

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2023-09-17 21:39:21 +08:00
Xyfacai
ecf8a6d875 feat: supprt channel priority now & record channel id in log (#484)
* feat: 支持设置渠道优先级 & 日志中显示使用的渠道ID

* fix: 设置渠道优先级未更新 ability

* chore: update implementation

---------

Co-authored-by: Xiangyuan Liu <xiangyuan.liu@ui.com>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2023-09-17 19:18:16 +08:00
igophper
24df3e5f62 feat: support non-stream mode for xunfei (#498)
* feat:xunfei suport none stream

* fix:join content ignore seq

---------

Co-authored-by: igophper <admin@jialilgu.cn>
2023-09-17 18:16:12 +08:00
JustSong
12ef9679a7 fix: fix url not passing when using custom chat_link 2023-09-17 17:19:12 +08:00
26 changed files with 393 additions and 202 deletions

View File

@@ -59,6 +59,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
> **Warning**
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> **Warning**
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
## 功能
1. 支持多种大模型:
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
@@ -103,11 +106,17 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
## 部署
### 基于 Docker 进行部署
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
```shell
# 使用 SQLite 的部署命令:
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
# 例如:
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
```
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
@@ -236,7 +245,7 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
<summary><strong>部署到 Zeabur</strong></summary>
<div>
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
@@ -251,6 +260,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
</div>
</details>
<details>
<summary><strong>部署到 Render</strong></summary>
<div>
> Render 提供免费额度,绑卡后可以进一步提升额度
Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashboard.render.com
</div>
</details>
## 配置
系统本身开箱即用。
@@ -269,13 +289,20 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意,具体的 API Base 的格式取决于你所使用的客户端。
例如对于 OpenAI 的官方库:
```bash
OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
```
```mermaid
graph LR
A(用户)
A --->|请求| B(One API)
A --->|使用 One API 分发的 key 进行请求| B(One API)
B -->|中继请求| C(OpenAI)
B -->|中继请求| D(Azure)
B -->|中继请求| E(其他下游渠道)
B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
```
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
@@ -303,22 +330,24 @@ graph LR
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
+ 例子:`SYNC_FREQUENCY=60`
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
12. 请求频率限制:
13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。

View File

@@ -56,6 +56,7 @@ var EmailDomainWhitelist = []string{
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
@@ -92,7 +93,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)

View File

@@ -24,6 +24,7 @@ var ModelRatio = map[string]float64{
"gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
@@ -50,8 +51,8 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-v1": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus-v1": 1, // ¥0.014 / 1k tokens
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens

View File

@@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
}
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
@@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.BaseURL == "" {
channel.BaseURL = baseURL
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
if channel.BaseURL != "" {
baseURL = channel.BaseURL
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
baseURL = channel.BaseURL
baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB:

View File

@@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
if channel.GetBaseURL() != "" {
requestURL = channel.GetBaseURL()
}
requestURL += "/v1/chat/completions"
}

View File

@@ -19,7 +19,8 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username")
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -106,7 +107,8 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name")
username := c.Query("username")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -126,7 +128,8 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(http.StatusOK, gin.H{
"success": true,

View File

@@ -117,6 +117,15 @@ func init() {
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-instruct",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-instruct",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
@@ -343,21 +352,21 @@ func init() {
Parent: nil,
},
{
Id: "qwen-v1",
Id: "qwen-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-v1",
Root: "qwen-turbo",
Parent: nil,
},
{
Id: "qwen-plus-v1",
Id: "qwen-plus",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus-v1",
Root: "qwen-plus",
Parent: nil,
},
{

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -18,6 +19,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
@@ -30,6 +32,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
@@ -107,7 +112,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -19,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
@@ -98,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
@@ -138,7 +139,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -38,6 +38,7 @@ func init() {
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
@@ -203,6 +204,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
@@ -364,7 +368,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
var textResponse TextResponse
tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id")
defer func(ctx context.Context) {
// c.Writer.Flush()
@@ -397,7 +400,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
@@ -541,24 +544,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil
}
case APITypeXunfei:
if isStream {
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
auth := c.Request.Header.Get("Authorization")
auth = strings.TrimPrefix(auth, "Bearer ")
splits := strings.Split(auth, "|")
if len(splits) != 3 {
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary:
if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp)

View File

@@ -9,44 +9,53 @@ import (
"net/http"
"one-api/common"
"strconv"
"strings"
)
var stopFinishReason = "stop"
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
tokenEncoderMap[model] = fallbackTokenEncoder
continue
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
tokenEncoderMap[model] = tokenEncoder
}
common.SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
return defaultTokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {

View File

@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: stopFinishReason,
}
fullTextResponse := OpenAITextResponse{
Object: "chat.completion",
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
}
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
setEventStreamHeaders(c)
var usage Usage
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
var usage Usage
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
}
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
return nil, nil, err
}
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
return nil, nil, err
}
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
return dataChan, stopChan, nil
}
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var xunfeiResponse XunfeiChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
if apiVersion == "" {
apiVersion = "v1.1"
common.SysLog("api_version not found, use default: " + apiVersion)
}
err = json.Unmarshal(responseBody, &xunfeiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
domain := "general"
if apiVersion == "v2.1" {
domain = "generalv2"
}
if xunfeiResponse.Header.Code != 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl
}

20
main.go
View File

@@ -2,6 +2,7 @@ package main
import (
"embed"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
@@ -50,18 +51,17 @@ func main() {
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if os.Getenv("SYNC_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
}
common.SyncFrequency = frequency
go model.SyncOptions(frequency)
if common.RedisEnabled {
go model.SyncChannelCache(frequency)
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))

View File

@@ -94,7 +94,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.IsUserEnabled(token.UserId)
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
return

View File

@@ -82,9 +82,9 @@ func Distribute() func(c *gin.Context) {
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.ModelMapping)
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL)
c.Set("base_url", channel.GetBaseURL())
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)

View File

@@ -10,15 +10,18 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
err = channelQuery.Order("RAND()").First(&ability).Error
}
if err != nil {
return nil, err
@@ -40,6 +43,7 @@ func (channel *Channel) AddAbilities() error {
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
@@ -159,6 +160,17 @@ func InitChannelCache() {
}
}
}
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelSyncLock.Unlock()
@@ -174,7 +186,7 @@ func SyncChannelCache(frequency int) {
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.RedisEnabled {
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
channelSyncLock.RLock()
@@ -183,6 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
idx := rand.Intn(len(channels))
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
}

View File

@@ -11,18 +11,19 @@ type Channel struct {
Key string `json:"key" gorm:"not null;index"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Weight int `json:"weight"`
Weight *uint `json:"weight" gorm:"default:0"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"`
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -78,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
}
return *channel.ModelMapping
}
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error

View File

@@ -8,17 +8,18 @@ import (
)
type Log struct {
Id int `json:"id"`
UserId int `json:"user_id"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
Type int `json:"type" gorm:"index"`
Id int `json:"id;index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
Username string `json:"username" gorm:"index;default:''"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"`
}
const (
@@ -46,8 +47,8 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
@@ -62,6 +63,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
}
err := DB.Create(log).Error
if err != nil {
@@ -69,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
@@ -91,6 +93,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err
}
@@ -128,8 +133,8 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
tx := DB.Table("logs").Select("sum(quota)")
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@@ -145,12 +150,15 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}

View File

@@ -81,6 +81,7 @@ func InitDB() (err error) {
if !common.IsMasterNode {
return nil
}
common.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return err

View File

@@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
import { Link } from 'react-router-dom';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
@@ -24,7 +24,7 @@ function renderType(type) {
}
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
}
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
}
function renderBalance(type, balance) {
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
});
}, []);
const manageChannel = async (id, action, idx) => {
const manageChannel = async (id, action, idx, priority) => {
let data = { id };
let res;
switch (action) {
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
data.status = 2;
res = await API.put('/api/channel/', data);
break;
case 'priority':
if (priority === '') {
return;
}
data.priority = parseInt(priority);
res = await API.put('/api/channel/', data);
break;
}
const { success, message } = res.data;
if (success) {
@@ -335,6 +342,14 @@ const ChannelsTable = () => {
>
余额
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
>
优先级
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row>
</Table.Header>
@@ -373,6 +388,22 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>
<Popup
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value,
);
}}>
<input style={{maxWidth:'60px'}} />
</Input>}
content='渠道选择优先级,越高越优先'
basic
/>
</Table.Cell>
<Table.Cell>
<div>
<Button
@@ -441,7 +472,7 @@ const ChannelsTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan='8'>
<Table.HeaderCell colSpan='9'>
<Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道
</Button>

View File

@@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react';
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess } from '../helpers';
import { getOAuthState, onGitHubOAuthClicked } from './utils';
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -68,8 +68,14 @@ const LoginForm = () => {
if (success) {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
navigate('/');
showSuccess('登录成功!');
if (username === 'root' && password === '123456') {
navigate('/user/edit');
showSuccess('登录成功!');
showWarning('请立刻修改默认密码!');
} else {
navigate('/token');
showSuccess('登录成功!');
}
} else {
showError(message);
}
@@ -126,7 +132,7 @@ const LoginForm = () => {
circular
color='black'
icon='github'
onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
/>
) : (
<></>

View File

@@ -56,9 +56,10 @@ const LogsTable = () => {
token_name: '',
model_name: '',
start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
const [stat, setStat] = useState({
quota: 0,
@@ -84,7 +85,7 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@@ -109,7 +110,7 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else {
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
@@ -205,16 +206,9 @@ const LogsTable = () => {
</Header>
<Form>
<Form.Group>
{
isAdminUser && (
<Form.Input fluid label={'用户名称'} width={2} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
)
}
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
<Form.Input fluid label={'令牌名称'} width={3} value={token_name}
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
name='model_name'
onChange={handleInputChange} />
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
@@ -225,6 +219,19 @@ const LogsTable = () => {
onChange={handleInputChange} />
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
</Form.Group>
{
isAdminUser && <>
<Form.Group>
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
placeholder='可选值' name='channel'
onChange={handleInputChange} />
<Form.Input fluid label={'用户名称'} width={3} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
</Form.Group>
</>
}
</Form>
<Table basic compact size='small'>
<Table.Header>
@@ -238,6 +245,17 @@ const LogsTable = () => {
>
时间
</Table.HeaderCell>
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('channel');
}}
width={1}
>
渠道
</Table.HeaderCell>
}
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
@@ -299,16 +317,16 @@ const LogsTable = () => {
onClick={() => {
sortLog('quota');
}}
width={2}
width={1}
>
消耗额度
额度
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('content');
}}
width={isAdminUser ? 4 : 5}
width={isAdminUser ? 4 : 6}
>
详情
</Table.HeaderCell>
@@ -326,6 +344,11 @@ const LogsTable = () => {
return (
<Table.Row key={log.id}>
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
{
isAdminUser && (
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
)
}
{
isAdminUser && (
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
@@ -345,7 +368,7 @@ const LogsTable = () => {
<Table.Footer>
<Table.Row>
<Table.HeaderCell colSpan={'9'}>
<Table.HeaderCell colSpan={'10'}>
<Select
placeholder='选择明细分类'
options={LOG_OPTIONS}

View File

@@ -96,7 +96,7 @@ const TokensTable = () => {
let nextUrl;
if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
}

View File

@@ -67,7 +67,7 @@ const EditChannel = () => {
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
break;
case 17:
localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
break;
case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
@@ -174,7 +174,7 @@ const EditChannel = () => {
return;
}
let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) {
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
}
if (localInputs.type === 3 && localInputs.other === '') {
@@ -183,9 +183,6 @@ const EditChannel = () => {
if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1';
}
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
}
let res;
localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(',');

View File

@@ -102,7 +102,7 @@ const EditUser = () => {
label='密码'
name='password'
type={'password'}
placeholder={'请输入新的密码'}
placeholder={'请输入新的密码,最短 8 位'}
onChange={handleInputChange}
value={password}
autoComplete='new-password'